Seems like fusing the ops into one kernel is faster.
commit 4be98aed30
parent 93b10bb894
@@ -5,6 +5,14 @@ import torch
 from loguru import logger
 
 logger.add("triton.log", rotation="1 MB", retention="10 days", level="DEBUG")
+
+autotune_results = []
+
+
+def log_autotune(config, timing):
+    autotune_results.append((config, timing))
+    logger.info(f"Tested: {config} → {timing * 1e6:.2f} us")
+
 configs = [
     # Configs for small matrices (e.g. M, N, K < 1024)
     triton.Config(
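
For context on how the new helpers might be fed: a minimal sketch, assuming `triton.testing.do_bench` (which returns a median time in milliseconds) is used to time one candidate config at a time. The `run_kernel` callable and this wiring are hypothetical, not part of the commit.

    import triton.testing

    def time_config(run_kernel, config):
        # do_bench reports milliseconds; log_autotune expects seconds,
        # since it prints timing * 1e6 as microseconds.
        ms = triton.testing.do_bench(lambda: run_kernel(config))
        log_autotune(config, ms * 1e-3)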
@@ -83,6 +91,7 @@ def matmul_kernel(
         b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
         acc += tl.dot(a_val, b_val)
     acc.to(tl.bfloat16)
+    acc = tl.sigmoid(acc)
     c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
     c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
     tl.store(c_ptr, acc, mask=c_mask)
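
This hunk is the actual fusion: the sigmoid is applied to the accumulator in the kernel epilogue, so the activation costs no extra kernel launch or round trip through global memory. One caveat in the surrounding context: `acc.to(tl.bfloat16)` returns a new value that is never assigned back, so the downcast is a no-op as written and the store happens at the accumulator's dtype. A sketch of the epilogue with the cast rebound, assuming a bf16 output was intended:

    acc = tl.sigmoid(acc)        # fused activation, computed before the downcast
    acc = acc.to(tl.bfloat16)    # rebind; a bare acc.to(...) is discarded
    tl.store(c_ptr, acc, mask=c_mask)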
@@ -146,13 +155,19 @@ def benchmark_triton(M, N, K):
     return time_ms
 
 
+def torch_matmul_sigmoid(a, b):
+    c = torch.matmul(a, b)
+    c = torch.sigmoid(c)
+    return c
+
+
 def benchmark_torch(M, N, K):
     # Assume your PyTorch matmul entry point is torch.matmul
     a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
     b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
 
     # Time it
-    time_ms = benchmark(torch.matmul, a, b)
+    time_ms = benchmark(torch_matmul_sigmoid, a, b)
     print(f"PyTorch time: {time_ms:.3f} ms")
     return time_ms
 
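
With the baseline switched to `torch_matmul_sigmoid`, both paths now compute matmul followed by sigmoid, so the timings are directly comparable. A minimal sketch of driving both benchmarks (the shape list and the driver loop are assumptions, not part of the commit):

    if __name__ == "__main__":
        # Compare the fused Triton kernel against the two-kernel PyTorch baseline.
        for M, N, K in [(512, 512, 512), (4096, 4096, 4096)]:
            t_triton = benchmark_triton(M, N, K)
            t_torch = benchmark_torch(M, N, K)
            logger.info(f"M={M} N={N} K={K}: triton {t_triton:.3f} ms, torch {t_torch:.3f} ms")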