Add some tuning optimizations to the script.
This commit is contained in:
parent
baaa5dbc1c
commit
5ac163f95c
@@ -8,12 +8,30 @@ import triton.language as tl
 # the dtype must be float32
 
 
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {
+                "BLOCK_M": M,
+                "BLOCK_N": N,
+                "BLOCK_K": K,
+            },
+            num_stages=s,
+            num_warps=nw,
+        )
+        for M in [64, 128]
+        for N in [64, 128]
+        for K in [32, 64]
+        for s in [2, 4]
+        for nw in [4, 8]
+    ],
+    key=["M", "N", "K"],
+)
 @triton.jit
 def matmul_kernel(
     a_ptr,
     b_ptr,
     c_ptr,
-    BS,
     M,
     N,
     K,
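For reference on what this decorator does: @triton.autotune benchmarks every triton.Config in configs (2·2·2·2·2 = 32 candidates here) the first time matmul_kernel is launched with a new value of the key tuple ["M", "N", "K"], then caches the winner and reuses it for later launches with the same key. A minimal sketch for checking which config was picked, assuming a Triton version whose autotuner wrapper exposes best_config and using this script's matmul wrapper (the tensor shapes below are only an example):

# Sketch only: needs the matmul / matmul_kernel definitions from this script
# and a CUDA device. best_config is an attribute of Triton's Autotuner wrapper
# and may differ across Triton versions.
import torch

a = torch.randn(size=(8, 1024, 1024), device="cuda", dtype=torch.float16)
b = torch.randn(size=(8, 1024, 1024), device="cuda", dtype=torch.float16)
matmul(a, b)                      # first launch for this (M, N, K) triggers tuning
print(matmul_kernel.best_config)  # chosen BLOCK_M / BLOCK_N / BLOCK_K, num_warps, num_stages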
@@ -101,7 +119,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
         a_ptr=a,
         b_ptr=b,
         c_ptr=c,
-        BS=a.size(0),
         M=M,
         N=N,
         K=K,
@@ -114,9 +131,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
         stride_cb=c.stride(0),
         stride_cm=c.stride(1),
         stride_cn=c.stride(2),
-        BLOCK_M=BLOCK_M,
-        BLOCK_N=BLOCK_N,
-        BLOCK_K=BLOCK_K,
     )
 
     return c
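BLOCK_M, BLOCK_N and BLOCK_K are dropped from the call site because the autotuner now injects them as meta-parameters. That only works if the launch grid is a callable that reads the chosen block sizes from the meta dict; the grid definition is outside the visible hunks, so the following is just a sketch of the usual pattern, assuming one program per batch element and output tile (consistent with the stride_cb=c.stride(0) argument kept above):

# Sketch, not the exact grid from this file: one program per
# (batch, BLOCK_M x BLOCK_N output tile). meta holds the config that
# @triton.autotune selected for the current (M, N, K) key.
grid = lambda meta: (
    a.size(0),
    triton.cdiv(M, meta["BLOCK_M"]),
    triton.cdiv(N, meta["BLOCK_N"]),
)

The launch itself then stays as in the hunks above, just without the three BLOCK_* keyword arguments.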
@@ -125,8 +139,17 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
 # Test
 if __name__ == "__main__":
     for i in range(100):
-        a = torch.randn(size=(128, 512, 256), device="cuda", dtype=torch.float16)
-        b = torch.randn(size=(128, 256, 384), device="cuda", dtype=torch.float16)
-        c_triton = matmul(a, b)
-        c_torch = torch.matmul(a, b)
-        print(f"Results match: {torch.allclose(c_triton, c_torch, atol=1e-2)}")
+        for M in [1024, 2048, 4096]:
+            for N in [1024, 2048, 4096]:
+                for K in [1024, 2048, 4096]:
+                    a = torch.randn(
+                        size=(128, M, K), device="cuda", dtype=torch.float16
+                    )
+                    b = torch.randn(
+                        size=(128, K, N), device="cuda", dtype=torch.float16
+                    )
+                    c_triton = matmul(a, b)
+                    c_torch = torch.matmul(a, b)
+                    print(
+                        f"Results match: {torch.allclose(c_triton, c_torch, atol=1e-4)}"
+                    )
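One practical note on the new test loop: with key=["M", "N", "K"], each of the 3·3·3 = 27 shape combinations triggers a one-off benchmarking pass over all 32 configs the first time it is hit inside for i in range(100); the remaining iterations reuse the cached configs, so only the first pass is slow. To time the tuned kernel itself rather than the tuning, something like triton.testing.do_bench can be used after a shape has been seen once (a sketch, assuming this script's matmul wrapper and that do_bench is available in the installed Triton):

# Sketch: benchmark the already-tuned kernel for one example shape.
# do_bench runs warmup plus timed repetitions and reports milliseconds.
import torch
import triton

a = torch.randn(size=(128, 1024, 1024), device="cuda", dtype=torch.float16)
b = torch.randn(size=(128, 1024, 1024), device="cuda", dtype=torch.float16)
matmul(a, b)  # first call pays the autotuning cost for this (M, N, K)
ms_triton = triton.testing.do_bench(lambda: matmul(a, b))
ms_torch = triton.testing.do_bench(lambda: torch.matmul(a, b))
print(f"triton: {ms_triton:.3f} ms, torch: {ms_torch:.3f} ms")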