torch_ext/tests/test_triton.py

195 lines
5.9 KiB
Python
Raw Normal View History

2025-03-29 11:56:50 +08:00
# coding=utf-8
2025-03-27 03:44:28 +08:00
import triton
import triton.language as tl
2025-03-29 11:56:50 +08:00
import torch
from loguru import logger
# Also send DEBUG-and-above log records to "triton.log"; loguru rotates the
# file once it reaches 1 MB and deletes rotated files older than 10 days.
logger.add("triton.log", level="DEBUG", rotation="1 MB", retention="10 days")
# (config, timing) pairs collected by log_autotune.
autotune_results = list()
def log_autotune(config, timing):
    """Record one autotuning measurement and log it at INFO level.

    Appends ``(config, timing)`` to the module-level ``autotune_results`` list.
    ``timing`` is multiplied by 1e6 and labelled "us", so it is presumably in
    seconds -- NOTE(review): confirm against whatever caller supplies it; no
    code visible in this file registers this function as a callback.
    """
    autotune_results.append((config, timing))
    # Separate the config repr from the timing so the log line stays readable.
    logger.info(f"Tested: {config} -> {timing * 1e6:.2f} us")
2025-03-29 11:56:50 +08:00
# Autotuning search space for matmul_kernel. Triton's autotuner benchmarks
# every entry for each new (m, n, k) key and keeps the fastest one; the
# entries are not selected by matrix size.
# Each row is (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_warps, num_stages).
# NOTE(review): the 256x256x64 / 5-stage and 256x128x128 / 3-stage tiles look
# likely to exceed per-block shared memory on many GPUs -- confirm they compile
# on the target hardware, otherwise they only add autotuning time.
configs = [
    triton.Config(
        {"BLOCK_SIZE_M": tile_m, "BLOCK_SIZE_N": tile_n, "BLOCK_SIZE_K": tile_k},
        num_warps=warps,
        num_stages=stages,
    )
    for tile_m, tile_n, tile_k, warps, stages in (
        (64, 64, 32, 4, 3),
        (128, 128, 32, 8, 4),
        (256, 256, 64, 16, 5),
        (64, 32, 64, 4, 3),
        (256, 128, 128, 8, 3),
        (128, 256, 64, 8, 3),
        (64, 32, 16, 8, 3),
    )
]
@triton.autotune(configs=configs, key=["m", "n", "k"])
@triton.jit
def matmul_kernel(
    a,
    b,
    c,
    m,
    n,
    k,
    stride_am,
    stride_an,
    stride_bm,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of c = sigmoid(a @ b).

    Program (pid 0, pid 1) handles output rows starting at pid0 * BLOCK_SIZE_M
    and columns starting at pid1 * BLOCK_SIZE_N. ``a`` is addressed as an
    (m, k) matrix and ``b`` as a (k, n) matrix via the given strides. Despite
    their names, ``stride_an`` is the stride along a's K dimension and
    ``stride_bm`` the stride along b's K dimension (names kept for callers).
    Out-of-range elements are masked: loads read 0.0 and stores are skipped.
    The product is accumulated in float32, then cast to c's element type.
    """
    mpid = tl.program_id(0)
    npid = tl.program_id(1)
    block_m = mpid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    block_n = npid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_pad in range(0, k, BLOCK_SIZE_K):
        k_range = tl.arange(0, BLOCK_SIZE_K) + k_pad
        a_ptr = a + (block_m[:, None] * stride_am + k_range[None, :] * stride_an)
        b_ptr = b + (k_range[:, None] * stride_bm + block_n[None, :] * stride_bn)
        # Zero-fill elements past the matrix edges so partial tiles add nothing.
        a_mask = (block_m[:, None] < m) & (k_range[None, :] < k)
        b_mask = (k_range[:, None] < k) & (block_n[None, :] < n)
        a_val = tl.load(a_ptr, mask=a_mask, other=0.0)
        b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
        acc += tl.dot(a_val, b_val)
    # Apply the sigmoid in float32, then cast explicitly to the output element
    # type. `.to()` returns a new tensor, so its result must be assigned.
    out = tl.sigmoid(acc).to(c.dtype.element_ty)
    c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
    c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
    tl.store(c_ptr, out, mask=c_mask)
def triton_matmul(a, b):
    """Compute ``sigmoid(a @ b)`` with the autotuned Triton ``matmul_kernel``.

    Args:
        a: 2-D tensor of shape (M, K). It is handed to a Triton kernel, so it is
            expected to live on a CUDA device.
        b: 2-D tensor of shape (K, N) on the same device as ``a``.

    Returns:
        A new ``torch.bfloat16`` tensor of shape (M, N) on ``a``'s device.

    Raises:
        ValueError: if an input is not 2-D or the inner dimensions of ``a`` and
            ``b`` differ. Without this check the kernel would index ``b`` using
            ``a``'s K and could read out of bounds.
    """
    if a.dim() != 2 or b.dim() != 2:
        raise ValueError(
            f"triton_matmul expects 2-D inputs, got shapes "
            f"{tuple(a.shape)} and {tuple(b.shape)}"
        )
    m, k = a.shape
    k_b, n = b.shape
    if k != k_b:
        raise ValueError(
            f"inner dimensions differ: a is {tuple(a.shape)}, b is {tuple(b.shape)}"
        )

    # Allocate the output on the inputs' device (not the default "cuda"
    # device) so inputs on a non-default GPU get an output on the same GPU.
    c = torch.empty(m, n, device=a.device, dtype=torch.bfloat16)

    def grid(meta):
        # One program per output tile; tile sizes come from the config the
        # autotuner picks for this (m, n, k).
        return (
            triton.cdiv(m, meta["BLOCK_SIZE_M"]),
            triton.cdiv(n, meta["BLOCK_SIZE_N"]),
        )

    matmul_kernel[grid](
        a,
        b,
        c,
        m=m,
        n=n,
        k=k,
        stride_am=a.stride(0),
        stride_an=a.stride(1),  # stride along a's K dimension
        stride_bm=b.stride(0),  # stride along b's K dimension
        stride_bn=b.stride(1),
        stride_cm=c.stride(0),
        stride_cn=c.stride(1),
    )
    return c
def benchmark(fn, *args, **kwargs):
    """Return the mean time of one ``fn(*args, **kwargs)`` call, in milliseconds.

    Makes 10 untimed warm-up calls so first-call initialization stays out of
    the measurement. It then times 100 back-to-back calls between a pair of
    CUDA timing events and averages them.
    """
    warmup_runs, timed_runs = 10, 100
    for _ in range(warmup_runs):
        fn(*args, **kwargs)

    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    # Drain the warm-up work before starting the clock.
    torch.cuda.synchronize()
    start.record()
    for _ in range(timed_runs):
        fn(*args, **kwargs)
    stop.record()
    # Both events must have completed before elapsed_time can be read.
    torch.cuda.synchronize()
    total_ms = start.elapsed_time(stop)
    return total_ms / timed_runs
def benchmark_triton(M, N, K):
    """Time ``triton_matmul`` on random bf16 CUDA inputs (M, K) @ (K, N).

    Prints the mean time per call and returns it in milliseconds.
    """
    lhs = torch.randn((M, K), device="cuda", dtype=torch.bfloat16)
    rhs = torch.randn((K, N), device="cuda", dtype=torch.bfloat16)
    elapsed_ms = benchmark(triton_matmul, lhs, rhs)
    print(f"Triton 耗时: {elapsed_ms:.3f} ms")
    return elapsed_ms
def torch_matmul_sigmoid(a, b):
    """Reference implementation of ``sigmoid(a @ b)`` using plain PyTorch ops."""
    return torch.sigmoid(a @ b)
2025-03-29 11:56:50 +08:00
def benchmark_torch(M, N, K):
    """Time the PyTorch reference on random bf16 CUDA inputs (M, K) @ (K, N).

    Prints the mean time per call and returns it in milliseconds.
    """
    lhs = torch.randn((M, K), device="cuda", dtype=torch.bfloat16)
    rhs = torch.randn((K, N), device="cuda", dtype=torch.bfloat16)
    elapsed_ms = benchmark(torch_matmul_sigmoid, lhs, rhs)
    print(f"PyTorch 耗时: {elapsed_ms:.3f} ms")
    return elapsed_ms
2025-03-27 03:44:28 +08:00
if __name__ == "__main__":
    # Sweep M over 1024..15360 and N, K over 1024..7168 (step 1024). For each
    # shape, log the Triton time divided by the PyTorch reference time.
    m_sizes = range(1024, 4096 * 4, 1024)
    nk_sizes = range(1024, 4096 * 2, 1024)
    for m in m_sizes:
        for n in nk_sizes:
            for k in nk_sizes:
                t_triton = benchmark_triton(m, n, k)
                t_torch = benchmark_torch(m, n, k)
                rate = t_triton / t_torch
                logger.info(f"{m}, {n}, {k} triton/torch rate: {rate}")