diff --git a/tests/test_block_md.py b/tests/test_block_md.py index fda816d..9cc4f8b 100644 --- a/tests/test_block_md.py +++ b/tests/test_block_md.py @@ -8,12 +8,30 @@ import triton.language as tl # the dtype must be float32 +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_M": M, + "BLOCK_N": N, + "BLOCK_K": K, + }, + num_stages=s, + num_warps=nw, + ) + for M in [64, 128] + for N in [64, 128] + for K in [32, 64] + for s in [2, 4] + for nw in [4, 8] + ], + key=["M", "N", "K"], +) @triton.jit def matmul_kernel( a_ptr, b_ptr, c_ptr, - BS, M, N, K, @@ -101,7 +119,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor): a_ptr=a, b_ptr=b, c_ptr=c, - BS=a.size(0), M=M, N=N, K=K, @@ -114,9 +131,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor): stride_cb=c.stride(0), stride_cm=c.stride(1), stride_cn=c.stride(2), - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_K=BLOCK_K, ) return c @@ -125,8 +139,17 @@ def matmul(a: torch.Tensor, b: torch.Tensor): # 测试 if __name__ == "__main__": for i in range(100): - a = torch.randn(size=(128, 512, 256), device="cuda", dtype=torch.float16) - b = torch.randn(size=(128, 256, 384), device="cuda", dtype=torch.float16) - c_triton = matmul(a, b) - c_torch = torch.matmul(a, b) - print(f"结果是否一致: {torch.allclose(c_triton, c_torch, atol=1e-2)}") + for M in [1024, 2048, 4096]: + for N in [1024, 2048, 4096]: + for K in [1024, 2048, 4096]: + a = torch.randn( + size=(128, M, K), device="cuda", dtype=torch.float16 + ) + b = torch.randn( + size=(128, K, N), device="cuda", dtype=torch.float16 + ) + c_triton = matmul(a, b) + c_torch = torch.matmul(a, b) + print( + f"结果是否一致: {torch.allclose(c_triton, c_torch, atol=1e-4)}" + )