Fusing the sigmoid into the matmul kernel seems to be faster than applying it separately.

longfei li 2025-03-29 16:37:38 +08:00
parent 93b10bb894
commit 4be98aed30


@@ -5,6 +5,14 @@ import torch
from loguru import logger
logger.add("triton.log", rotation="1 MB", retention="10 days", level="DEBUG")
autotune_results = []
def log_autotune(config, timing):
autotune_results.append((config, timing))
logger.info(f"Tested: {config}{timing * 1e6:.2f} us")
configs = [
# Configs for small matrices (e.g. M, N, K < 1024)
triton.Config(
@@ -83,6 +91,7 @@ def matmul_kernel(
b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
acc += tl.dot(a_val, b_val)
acc = tl.sigmoid(acc)
acc = acc.to(tl.bfloat16)
c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
tl.store(c_ptr, acc, mask=c_mask)
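
The diff fuses the sigmoid into the kernel epilogue but does not show a numerical check against the eager path. A minimal sketch of such a check, assuming a wrapper (here given the hypothetical name triton_matmul_sigmoid) that launches matmul_kernel and returns the output tensor:

# Hypothetical correctness check (not part of this commit). Assumes a wrapper
# `triton_matmul_sigmoid(a, b)` that allocates the output and launches matmul_kernel.
import torch

def check_fused_matches_eager(M=512, N=512, K=512, atol=1e-2, rtol=1e-2):
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    ref = torch.sigmoid(torch.matmul(a, b))   # eager reference
    out = triton_matmul_sigmoid(a, b)         # fused Triton kernel
    # bf16 accumulation differences warrant loose tolerances
    assert torch.allclose(out.to(torch.float32), ref.to(torch.float32),
                          atol=atol, rtol=rtol)
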
@@ -146,13 +155,19 @@ def benchmark_torch(M, N, K):
return time_ms
def torch_matmul_sigmoid(a, b):
c = torch.matmul(a, b)
c = torch.sigmoid(c)
return c
def benchmark_torch(M, N, K):
# Assumes your PyTorch matrix-multiply function is torch.matmul
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
# Measure runtime
time_ms = benchmark(torch.matmul, a, b)
time_ms = benchmark(torch_matmul_sigmoid, a, b)
print(f"PyTorch 耗时: {time_ms:.3f} ms")
return time_ms
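
The benchmark helper called above is not shown in this hunk. A minimal sketch of what such a helper might look like, assuming it takes a callable plus its arguments and returns the mean runtime in milliseconds (the CUDA-event timing below is an assumption, not the repository's actual implementation):

# Hypothetical sketch of the `benchmark` helper used above: warm up, then time
# with CUDA events and return the mean time per call in milliseconds.
import torch

def benchmark(fn, *args, warmup=10, iters=100):
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

With a helper of this shape, benchmark(torch_matmul_sigmoid, a, b) times the eager matmul-plus-sigmoid path the same way benchmark_triton times the fused kernel, which is what makes the comparison in the commit message meaningful.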