From 5ac163f95c69f5782d3127b96b6d0f244230e7d6 Mon Sep 17 00:00:00 2001 From: longfei li Date: Sat, 12 Apr 2025 14:21:52 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=A0=E4=B8=8A=E4=B8=80=E4=BA=9B=E7=82=B9?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E7=9A=84=E8=84=9A=E6=9C=AC=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_block_md.py | 43 ++++++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/tests/test_block_md.py b/tests/test_block_md.py index fda816d..9cc4f8b 100644 --- a/tests/test_block_md.py +++ b/tests/test_block_md.py @@ -8,12 +8,30 @@ import triton.language as tl # the dtype must be float32 +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_M": M, + "BLOCK_N": N, + "BLOCK_K": K, + }, + num_stages=s, + num_warps=nw, + ) + for M in [64, 128] + for N in [64, 128] + for K in [32, 64] + for s in [2, 4] + for nw in [4, 8] + ], + key=["M", "N", "K"], +) @triton.jit def matmul_kernel( a_ptr, b_ptr, c_ptr, - BS, M, N, K, @@ -101,7 +119,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor): a_ptr=a, b_ptr=b, c_ptr=c, - BS=a.size(0), M=M, N=N, K=K, @@ -114,9 +131,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor): stride_cb=c.stride(0), stride_cm=c.stride(1), stride_cn=c.stride(2), - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_K=BLOCK_K, ) return c @@ -125,8 +139,17 @@ def matmul(a: torch.Tensor, b: torch.Tensor): # 测试 if __name__ == "__main__": for i in range(100): - a = torch.randn(size=(128, 512, 256), device="cuda", dtype=torch.float16) - b = torch.randn(size=(128, 256, 384), device="cuda", dtype=torch.float16) - c_triton = matmul(a, b) - c_torch = torch.matmul(a, b) - print(f"结果是否一致: {torch.allclose(c_triton, c_torch, atol=1e-2)}") + for M in [1024, 2048, 4096]: + for N in [1024, 2048, 4096]: + for K in [1024, 2048, 4096]: + a = torch.randn( + size=(128, M, K), device="cuda", dtype=torch.float16 + ) + b = torch.randn( + size=(128, K, N), device="cuda", dtype=torch.float16 + ) + c_triton = matmul(a, b) + c_torch = torch.matmul(a, b) + print( + f"结果是否一致: {torch.allclose(c_triton, c_torch, atol=1e-4)}" + )