# torch_ext/tests/test_triton.py
# coding=utf-8
import triton
import triton.language as tl
import torch
from loguru import logger

logger.add("triton.log", rotation="1 MB", retention="10 days", level="DEBUG")

autotune_results = []


def log_autotune(config, timing):
    autotune_results.append((config, timing))
    logger.info(f"Tested: {config} -> {timing * 1e6:.2f} us")


configs = [
    # Config for small matrices (e.g. M, N, K < 1024)
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
        num_warps=4,   # warps per thread block (4 warps = 128 threads)
        num_stages=3,  # pipeline stages (friendlier to compute-bound kernels)
    ),
    # Config for medium matrices (e.g. 1024 <= M, N, K < 4096)
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
        num_warps=8,
        num_stages=4,
    ),
    # Config for large matrices (e.g. M, N, K >= 4096)
    triton.Config(
        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
        num_warps=16,
        num_stages=5,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
        num_warps=4,
        num_stages=3,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
        num_warps=8,
        num_stages=3,
    ),
    triton.Config(
        {
            "BLOCK_SIZE_M": 128,  # tall tile along M
            "BLOCK_SIZE_N": 256,  # wide tile along N (exploits column-direction parallelism)
            "BLOCK_SIZE_K": 64,   # larger K tile to raise the compute-to-memory ratio
        },
        num_warps=8,   # 8 warps per thread block (256 threads)
        num_stages=3,  # pipeline stages (balances register pressure against latency hiding)
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 16},
        num_warps=8,   # 8 warps per thread block (256 threads)
        num_stages=3,
        # Add more parameter combinations here...
    ),
]


@triton.autotune(configs=configs, key=["m", "n", "k"])
@triton.jit
def matmul_kernel(
    a,
    b,
    c,
    m,
    n,
    k,
    stride_am,
    stride_an,
    stride_bm,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Each program instance computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C.
    mpid = tl.program_id(0)
    npid = tl.program_id(1)
    block_m = mpid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    block_n = npid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # March along K in BLOCK_SIZE_K steps, accumulating in fp32.
    for k_pad in range(0, k, BLOCK_SIZE_K):
        k_range = tl.arange(0, BLOCK_SIZE_K) + k_pad
        a_ptr = a + (block_m[:, None] * stride_am + k_range[None, :] * stride_an)
        b_ptr = b + (k_range[:, None] * stride_bm + block_n[None, :] * stride_bn)
        a_mask = (block_m[:, None] < m) & (k_range[None, :] < k)
        b_mask = (k_range[:, None] < k) & (block_n[None, :] < n)
        a_val = tl.load(a_ptr, mask=a_mask, other=0.0)
        b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
        acc += tl.dot(a_val, b_val)
    # Fused epilogue: apply sigmoid in fp32, then cast down to the bf16 output dtype.
    acc = tl.sigmoid(acc)
    acc = acc.to(tl.bfloat16)
    c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
    c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
    tl.store(c_ptr, acc, mask=c_mask)


def triton_matmul(a, b):
    # Allocate the output tensor (bf16 to match the PyTorch reference).
    c = torch.empty(a.size(0), b.size(1), device="cuda", dtype=torch.bfloat16)
    # Grid: one program per BLOCK_SIZE_M x BLOCK_SIZE_N output tile.
    grid = lambda meta: (
        triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
        triton.cdiv(b.size(1), meta["BLOCK_SIZE_N"]),
    )
    # Launch the kernel
    matmul_kernel[grid](
        a,
        b,
        c,
        m=a.size(0),
        n=b.size(1),
        k=a.size(1),
        stride_am=a.stride(0),
        stride_an=a.stride(1),
        stride_bm=b.stride(0),
        stride_bn=b.stride(1),
        stride_cm=c.stride(0),
        stride_cn=c.stride(1),
    )
    return c


def benchmark(fn, *args, **kwargs):
    # Warm up the GPU so one-time initialization does not skew the timing.
    for _ in range(10):
        fn(*args, **kwargs)
    # Time with CUDA events.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    for _ in range(100):  # average over multiple runs
        fn(*args, **kwargs)
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / 100  # milliseconds per call


def benchmark_triton(M, N, K):
    # Times the Triton path (matmul_kernel launched through triton_matmul).
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    time_ms = benchmark(triton_matmul, a, b)
    print(f"Triton time: {time_ms:.3f} ms")
    return time_ms


def torch_matmul_sigmoid(a, b):
    c = torch.matmul(a, b)
    c = torch.sigmoid(c)
    return c
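

# A small sanity check that is not part of the original benchmark: it compares
# triton_matmul against the PyTorch reference above on a few shapes, including
# ones that do not divide the tile sizes, to exercise the boundary masks.
# The shapes and the atol/rtol values are assumptions chosen loosely for bf16.
def check_correctness():
    torch.manual_seed(0)
    for M, N, K in [(128, 128, 128), (257, 123, 515), (1024, 2048, 512)]:
        a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
        b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
        out_triton = triton_matmul(a, b)
        out_torch = torch_matmul_sigmoid(a, b)
        assert torch.allclose(
            out_triton.float(), out_torch.float(), atol=1e-2, rtol=1e-2
        ), f"mismatch at shape {(M, N, K)}"
    logger.info("check_correctness passed")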


def benchmark_torch(M, N, K):
    # Times the PyTorch reference (torch.matmul followed by torch.sigmoid).
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    time_ms = benchmark(torch_matmul_sigmoid, a, b)
    print(f"PyTorch time: {time_ms:.3f} ms")
    return time_ms


if __name__ == "__main__":
    for M in range(1024, 4096 * 4, 1024):
        for N in range(1024, 4096 * 2, 1024):
            for K in range(1024, 4096 * 2, 1024):
                triton_cost_time = benchmark_triton(M, N, K)
                torch_cost_time = benchmark_torch(M, N, K)
                logger.info(
                    f"{M}, {N}, {K} triton/torch rate: {triton_cost_time / torch_cost_time}",
                )
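
    # After the sweep, report what the autotuner settled on and run the sanity
    # check defined above. Assumption: the Autotuner wrapper exposes a
    # `best_config` attribute (present in recent Triton releases), so it is
    # read defensively rather than accessed directly.
    best = getattr(matmul_kernel, "best_config", None)
    if best is not None:
        logger.info(f"Autotuner best config (most recently tuned shape): {best}")
    check_correctness()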