It looks like the results may differ once the matrix sizes get larger.

long0x0 2025-03-28 23:29:41 +08:00
parent 4774d3ef39
commit 374cd36597


@@ -27,6 +27,11 @@ configs = [
num_warps=4,
num_stages=3,
),
+triton.Config(
+{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
+num_warps=16,
+num_stages=5,
+),
# Add more parameter combinations...
]
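
For context, here is a minimal sketch (not part of this commit) of how a config list like the one above is typically attached to the kernel via triton.autotune, so that each triton.Config is benchmarked and the fastest one is cached per (M, N, K); the name matmul_kernel and its signature are assumptions based on the comments further down.

import triton
import triton.language as tl

@triton.autotune(configs=configs, key=["M", "N", "K"])  # re-runs tuning whenever a key value changes
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # Tiled matmul body omitted; BLOCK_SIZE_* are filled in from the selected Config.
    ...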
@@ -121,7 +126,7 @@ def benchmark(fn, *args, **kwargs):
def benchmark_triton():
# Assume the Triton kernel function is named matmul_kernel and is invoked via triton_matmul
-M, N, K = 4096, 4096, 4096 # matrix dimensions
+M, N, K = 4096 * 4, 4096 * 2, 4096 * 2 # matrix dimensions
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
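
As a reference for how these tensors might be consumed, here is a hedged sketch of the assumed triton_matmul wrapper (names and launch details are not taken from this diff): it allocates the output and launches matmul_kernel with one program per output tile, block sizes being supplied by the autotuner.

def triton_matmul(a, b):
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # One program per BLOCK_SIZE_M x BLOCK_SIZE_N output tile; sizes come from the chosen Config.
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
    )
    matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c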
@@ -132,7 +137,7 @@ def benchmark_triton():
def benchmark_torch():
# Assume the PyTorch matrix multiplication function is torch.matmul
-M, N, K = 4096, 4096, 4096 # matrix dimensions
+M, N, K = 4096 * 4, 4096 * 2, 4096 * 2 # matrix dimensions
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
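
The benchmark(fn, *args, **kwargs) helper referenced in the hunk header above is not shown in this diff; a hedged sketch of one common CUDA-event-based implementation (the warmup and iteration counts are assumptions):

import torch

def benchmark(fn, *args, warmup=10, iters=100, **kwargs):
    # Warm-up runs so compilation and autotuning cost is excluded from the timing.
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call

With the new shapes (M = 16384, N = K = 8192) this measures a much larger GEMM than the original 4096^3 case, which is what the commit message is probing.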