加上一些点优化的脚本。
This commit is contained in:
parent
baaa5dbc1c
commit
5ac163f95c
@ -8,12 +8,30 @@ import triton.language as tl
|
||||
# the dtype must be float32
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": M,
|
||||
"BLOCK_N": N,
|
||||
"BLOCK_K": K,
|
||||
},
|
||||
num_stages=s,
|
||||
num_warps=nw,
|
||||
)
|
||||
for M in [64, 128]
|
||||
for N in [64, 128]
|
||||
for K in [32, 64]
|
||||
for s in [2, 4]
|
||||
for nw in [4, 8]
|
||||
],
|
||||
key=["M", "N", "K"],
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
BS,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@ -101,7 +119,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
|
||||
a_ptr=a,
|
||||
b_ptr=b,
|
||||
c_ptr=c,
|
||||
BS=a.size(0),
|
||||
M=M,
|
||||
N=N,
|
||||
K=K,
|
||||
@ -114,9 +131,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
|
||||
stride_cb=c.stride(0),
|
||||
stride_cm=c.stride(1),
|
||||
stride_cn=c.stride(2),
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_K=BLOCK_K,
|
||||
)
|
||||
|
||||
return c
|
||||
@ -125,8 +139,17 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
|
||||
# 测试
|
||||
if __name__ == "__main__":
|
||||
for i in range(100):
|
||||
a = torch.randn(size=(128, 512, 256), device="cuda", dtype=torch.float16)
|
||||
b = torch.randn(size=(128, 256, 384), device="cuda", dtype=torch.float16)
|
||||
c_triton = matmul(a, b)
|
||||
c_torch = torch.matmul(a, b)
|
||||
print(f"结果是否一致: {torch.allclose(c_triton, c_torch, atol=1e-2)}")
|
||||
for M in [1024, 2048, 4096]:
|
||||
for N in [1024, 2048, 4096]:
|
||||
for K in [1024, 2048, 4096]:
|
||||
a = torch.randn(
|
||||
size=(128, M, K), device="cuda", dtype=torch.float16
|
||||
)
|
||||
b = torch.randn(
|
||||
size=(128, K, N), device="cuda", dtype=torch.float16
|
||||
)
|
||||
c_triton = matmul(a, b)
|
||||
c_torch = torch.matmul(a, b)
|
||||
print(
|
||||
f"结果是否一致: {torch.allclose(c_triton, c_torch, atol=1e-4)}"
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user