看起来尺寸大了以后效果可能会有差异。
This commit is contained in:
parent
4774d3ef39
commit
374cd36597
@ -27,6 +27,11 @@ configs = [
|
||||
num_warps=4,
|
||||
num_stages=3,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
|
||||
num_warps=16,
|
||||
num_stages=5,
|
||||
),
|
||||
# 添加更多参数组合...
|
||||
]
|
||||
|
||||
@ -121,7 +126,7 @@ def benchmark(fn, *args, **kwargs):
|
||||
|
||||
def benchmark_triton():
|
||||
# 假设你的 Triton 核函数名为 matmul_kernel,调用方法为 triton_matmul
|
||||
M, N, K = 4096, 4096, 4096 # 矩阵尺寸
|
||||
M, N, K = 4096 * 4, 4096 * 2, 4096 * 2 # 矩阵尺寸
|
||||
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
|
||||
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
@ -132,7 +137,7 @@ def benchmark_triton():
|
||||
|
||||
def benchmark_torch():
|
||||
# 假设你的 PyTorch 矩阵乘法函数名为 torch.matmul
|
||||
M, N, K = 4096, 4096, 4096 # 矩阵尺寸
|
||||
M, N, K = 4096 * 4, 4096 * 2, 4096 * 2 # 矩阵尺寸
|
||||
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
|
||||
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user