It looks like the results may differ once the matrix sizes get larger.

long0x0 2025-03-28 23:29:41 +08:00
parent 4774d3ef39
commit 374cd36597


@@ -27,6 +27,11 @@ configs = [
         num_warps=4,
         num_stages=3,
     ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
+        num_warps=16,
+        num_stages=5,
+    ),
     # Add more parameter combinations...
 ]
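
The hunk above only adds a new entry to the configs list; the kernel it feeds is not shown in this diff. For context, here is a minimal sketch of how such a list is typically wired into a Triton matmul via @triton.autotune. The kernel body is the standard tiled-matmul pattern from the Triton tutorials, not this repo's actual matmul_kernel; the argument names and the triton_matmul wrapper are assumptions.

import torch
import triton
import triton.language as tl

configs = [
    triton.Config(
        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
        num_warps=16,
        num_stages=5,
    ),
    # ...more parameter combinations, as in the list above
]

@triton.autotune(configs=configs, key=["M", "N", "K"])  # re-tunes per problem size
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    # Wrap row/column offsets so loads stay in bounds; the K tail is masked below.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Store with unwrapped offsets and an explicit boundary mask.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator.to(tl.bfloat16), mask=c_mask)

def triton_matmul(a, b):
    # Hypothetical wrapper matching the name used in benchmark_triton below.
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_SIZE_M"]),
        triton.cdiv(N, meta["BLOCK_SIZE_N"]),
    )
    matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
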
@@ -121,7 +126,7 @@ def benchmark(fn, *args, **kwargs):
 def benchmark_triton():
     # Assume your Triton kernel is named matmul_kernel and is invoked via triton_matmul
-    M, N, K = 4096, 4096, 4096  # matrix dimensions
+    M, N, K = 4096 * 4, 4096 * 2, 4096 * 2  # matrix dimensions
     a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
     b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
@@ -132,7 +137,7 @@ def benchmark_triton():
 def benchmark_torch():
     # Assume your PyTorch matrix-multiply function is torch.matmul
-    M, N, K = 4096, 4096, 4096  # matrix dimensions
+    M, N, K = 4096 * 4, 4096 * 2, 4096 * 2  # matrix dimensions
     a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
     b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
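
The hunk headers reference a benchmark(fn, *args, **kwargs) helper whose body is not part of this diff. Here is a minimal sketch of what such a helper commonly looks like, timed with CUDA events; the warmup and iteration counts are assumptions, not the repo's actual values.

import torch

def benchmark(fn, *args, warmup=10, iters=100, **kwargs):
    # Warm-up runs absorb one-time costs (CUDA context init, Triton autotuning),
    # which matters here because @triton.autotune fires on the first call per key.
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()  # wait for all timed kernels to finish
    return start.elapsed_time(end) / iters  # mean milliseconds per call

At the new sizes (M, N, K = 16384, 8192, 8192) each call performs 2*M*N*K ≈ 2.2e12 FLOPs, 16x the work of the old 4096^3 case, so a per-call time of t milliseconds converts to throughput as roughly 2199 / t TFLOP/s.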