seems like merge to kernel together is faster.
This commit is contained in:
parent
93b10bb894
commit
4be98aed30
@ -5,6 +5,14 @@ import torch
|
||||
from loguru import logger
|
||||
|
||||
logger.add("triton.log", rotation="1 MB", retention="10 days", level="DEBUG")
|
||||
autotune_results = []
|
||||
|
||||
|
||||
def log_autotune(config, timing):
|
||||
autotune_results.append((config, timing))
|
||||
logger.info(f"Tested: {config} → {timing * 1e6:.2f} us")
|
||||
|
||||
|
||||
configs = [
|
||||
# 针对小规模矩阵的配置(例如 M,N,K < 1024)
|
||||
triton.Config(
|
||||
@ -83,6 +91,7 @@ def matmul_kernel(
|
||||
b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
|
||||
acc += tl.dot(a_val, b_val)
|
||||
acc.to(tl.bfloat16)
|
||||
acc = tl.sigmoid(acc)
|
||||
c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
|
||||
c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
|
||||
tl.store(c_ptr, acc, mask=c_mask)
|
||||
@ -146,13 +155,19 @@ def benchmark_triton(M, N, K):
|
||||
return time_ms
|
||||
|
||||
|
||||
def torch_matmul_sigmoid(a, b):
|
||||
c = torch.matmul(a, b)
|
||||
c = torch.sigmoid(c)
|
||||
return c
|
||||
|
||||
|
||||
def benchmark_torch(M, N, K):
|
||||
# 假设你的 PyTorch 矩阵乘法函数名为 torch.matmul
|
||||
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
|
||||
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# 测速
|
||||
time_ms = benchmark(torch.matmul, a, b)
|
||||
time_ms = benchmark(torch_matmul_sigmoid, a, b)
|
||||
print(f"PyTorch 耗时: {time_ms:.3f} ms")
|
||||
return time_ms
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user