From 374cd365979d81b76398fccc2216c5fa366df765 Mon Sep 17 00:00:00 2001 From: long0x0 Date: Fri, 28 Mar 2025 23:29:41 +0800 Subject: [PATCH] =?UTF-8?q?=E7=9C=8B=E8=B5=B7=E6=9D=A5=E5=B0=BA=E5=AF=B8?= =?UTF-8?q?=E5=A4=A7=E4=BA=86=E4=BB=A5=E5=90=8E=E6=95=88=E6=9E=9C=E5=8F=AF?= =?UTF-8?q?=E8=83=BD=E4=BC=9A=E6=9C=89=E5=B7=AE=E5=BC=82=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test_triton.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test_triton.py b/test_triton.py index 767d803..bc3621a 100644 --- a/test_triton.py +++ b/test_triton.py @@ -27,6 +27,11 @@ configs = [ num_warps=4, num_stages=3, ), + triton.Config( + {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, + num_warps=16, + num_stages=5, + ), # 添加更多参数组合... ] @@ -121,7 +126,7 @@ def benchmark(fn, *args, **kwargs): def benchmark_triton(): # 假设你的 Triton 核函数名为 matmul_kernel,调用方法为 triton_matmul - M, N, K = 4096, 4096, 4096 # 矩阵尺寸 + M, N, K = 4096 * 4, 4096 * 2, 4096 * 2 # 矩阵尺寸 a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) @@ -132,7 +137,7 @@ def benchmark_triton(): def benchmark_torch(): # 假设你的 PyTorch 矩阵乘法函数名为 torch.matmul - M, N, K = 4096, 4096, 4096 # 矩阵尺寸 + M, N, K = 4096 * 4, 4096 * 2, 4096 * 2 # 矩阵尺寸 a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)