From 4be98aed30a0124654dda48cfe3218821597f5d3 Mon Sep 17 00:00:00 2001
From: longfei li
Date: Sat, 29 Mar 2025 16:37:38 +0800
Subject: [PATCH] Fusing the sigmoid into the matmul kernel seems faster

Fusing the elementwise sigmoid into the matmul kernel avoids a separate
kernel launch and an extra round trip of the output through global memory.
---
 tests/test_triton.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/tests/test_triton.py b/tests/test_triton.py
index 2ea0f34..98b8180 100644
--- a/tests/test_triton.py
+++ b/tests/test_triton.py
@@ -5,6 +5,14 @@
 import torch
 from loguru import logger
 logger.add("triton.log", rotation="1 MB", retention="10 days", level="DEBUG")
+autotune_results = []
+
+
+def log_autotune(config, timing):
+    autotune_results.append((config, timing))
+    logger.info(f"Tested: {config} → {timing * 1e6:.2f} us")
+
+
 configs = [
     # Configs for small matrices (e.g. M, N, K < 1024)
     triton.Config(
@@ -83,6 +91,7 @@ def matmul_kernel(
         b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
         acc += tl.dot(a_val, b_val)
-    acc.to(tl.bfloat16)
+    acc = tl.sigmoid(acc)
+    acc = acc.to(tl.bfloat16)
     c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
     c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
     tl.store(c_ptr, acc, mask=c_mask)
@@ -146,13 +155,19 @@ def benchmark_triton(M, N, K):
     return time_ms
 
 
+def torch_matmul_sigmoid(a, b):
+    c = torch.matmul(a, b)
+    c = torch.sigmoid(c)
+    return c
+
+
 def benchmark_torch(M, N, K):
     # Assumes your PyTorch matmul entry point is torch.matmul
     a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
     b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
 
     # Time it
-    time_ms = benchmark(torch.matmul, a, b)
+    time_ms = benchmark(torch_matmul_sigmoid, a, b)
     print(f"PyTorch time: {time_ms:.3f} ms")
     return time_ms
 
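
Note on the harness: the benchmark helper called by benchmark_triton and
benchmark_torch is defined outside these hunks, so only its call shape is
visible here (benchmark(fn, *args) returning milliseconds). A minimal
CUDA-event-based sketch of such a helper, under those assumptions:

import torch


def benchmark(fn, *args, warmup=10, iters=100):
    # Warm-up runs absorb one-time costs (Triton autotuning, JIT compilation)
    # so they do not pollute the measurement.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per call

The warm-up matters here in particular: the first Triton calls pay the full
autotuning sweep over the configs list, which would otherwise dominate the
reported time.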
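Since the kernel now fuses a sigmoid, the timing comparison is only
meaningful if both paths compute the same function. A quick correctness
check could look like the following; triton_matmul_sigmoid is a hypothetical
host-side launcher for matmul_kernel, which this patch does not show:

import torch

M, N, K = 512, 512, 512
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

ref = torch_matmul_sigmoid(a, b)   # PyTorch reference added in this patch
out = triton_matmul_sigmoid(a, b)  # hypothetical launcher for matmul_kernel

# bf16 accumulation and the fused sigmoid warrant loose tolerances
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)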