加上一些优化的脚本。

This commit is contained in:
longfei li 2025-04-12 14:21:52 +08:00
parent baaa5dbc1c
commit 5ac163f95c

View File

@@ -8,12 +8,30 @@ import triton.language as tl
# the dtype must be float32
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_M": M,
"BLOCK_N": N,
"BLOCK_K": K,
},
num_stages=s,
num_warps=nw,
)
for M in [64, 128]
for N in [64, 128]
for K in [32, 64]
for s in [2, 4]
for nw in [4, 8]
],
key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
a_ptr,
b_ptr,
c_ptr,
BS,
M,
N,
K,
@@ -101,7 +119,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
a_ptr=a,
b_ptr=b,
c_ptr=c,
BS=a.size(0),
M=M,
N=N,
K=K,
@@ -114,9 +131,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
stride_cb=c.stride(0),
stride_cm=c.stride(1),
stride_cn=c.stride(2),
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)
return c
@@ -125,8 +139,17 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
# 测试 (test): compare the Triton matmul against torch.matmul over a sweep
# of batched shapes, repeated 100 times (repeats also serve to exercise the
# autotuner cache after the first pass).
if __name__ == "__main__":
    for i in range(100):
        for M in [1024, 2048, 4096]:
            for N in [1024, 2048, 4096]:
                for K in [1024, 2048, 4096]:
                    # NOTE(review): at (128, 4096, 4096) each fp16 tensor is
                    # ~4 GB, so a/b/c alone need ~12 GB of GPU memory — confirm
                    # the target GPU can hold them.
                    a = torch.randn(
                        size=(128, M, K), device="cuda", dtype=torch.float16
                    )
                    b = torch.randn(
                        size=(128, K, N), device="cuda", dtype=torch.float16
                    )
                    c_triton = matmul(a, b)
                    c_torch = torch.matmul(a, b)
                    # Fix: atol=1e-4 is below fp16 resolution for these shapes.
                    # Output magnitudes scale like sqrt(K) (randn inputs), and the
                    # fp16 ulp at magnitude ~64 is ~0.03, so even a bit-correct
                    # kernel would fail a 1e-4 absolute check. Compare with an
                    # fp16-appropriate relative + absolute tolerance instead.
                    print(
                        f"结果是否一致: {torch.allclose(c_triton, c_torch, rtol=1e-2, atol=1e-2)}"
                    )