# Source: torch_ext/test_triton.py (147 lines, 4.2 KiB, Python)
#
# NOTE: the repository viewer flagged this file as containing ambiguous
# Unicode characters, i.e. characters that might be confused with other
# characters. Presumably this refers to full-width punctuation in the
# original comments; if that is intentional the warning can be ignored.

# coding=utf-8
import triton
import triton.language as tl
import torch
# Candidate launch configurations handed to ``triton.autotune``.
#
# Each row is (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_warps, num_stages):
#   * BLOCK_SIZE_* are the tile extents used by the kernel.
#   * num_warps is the number of warps per thread block (4 warps = 128 threads).
#   * num_stages is the number of software-pipeline stages (per the original
#     author, higher values are friendlier to compute-intensive operations).
#
# The size ranges in the row comments are the original author's intended
# targets only; nothing in this list ties a configuration to a problem size.
configs = [
    triton.Config(
        {"BLOCK_SIZE_M": tile_m, "BLOCK_SIZE_N": tile_n, "BLOCK_SIZE_K": tile_k},
        num_warps=warps,
        num_stages=stages,
    )
    for tile_m, tile_n, tile_k, warps, stages in (
        # Intended for small matrices (e.g. M, N, K < 1024).
        (64, 64, 32, 4, 3),
        # Intended for medium matrices (e.g. 1024 <= M, N, K < 4096).
        (128, 128, 32, 8, 4),
        # Intended for large matrices (e.g. M, N, K >= 4096).
        (256, 256, 64, 16, 5),
        # Extra candidate: narrow N tile with a deeper K tile.
        (64, 32, 64, 4, 3),
        # Add more parameter combinations here...
    )
]
@triton.autotune(configs=configs, key=["m", "n", "k"])
@triton.jit
def matmul_kernel(
    a,
    b,
    c,
    m,
    n,
    k,
    stride_am,
    stride_an,
    stride_bm,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``c = a @ b``.

    Program id 0 selects the tile along the rows of ``c`` and program id 1
    selects the tile along its columns. The reduction dimension is walked in
    steps of BLOCK_SIZE_K and accumulated in float32; the result is cast to
    bfloat16 before being stored.

    Args:
        a: Pointer to the left operand, indexed as (row < m, col < k).
        b: Pointer to the right operand, indexed as (row < k, col < n).
        c: Pointer to the output, indexed as (row < m, col < n).
        m, n, k: Problem extents; also the autotune cache key.
        stride_am, stride_an: Element strides of ``a`` along its rows and
            along its columns (the reduction axis, despite the ``n`` suffix).
        stride_bm, stride_bn: Element strides of ``b`` along its rows (the
            reduction axis, despite the ``m`` suffix) and along its columns.
        stride_cm, stride_cn: Element strides of ``c`` along rows and columns.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Compile-time tile extents
            supplied by the selected autotune configuration.
    """
    mpid = tl.program_id(0)
    npid = tl.program_id(1)

    # Row / column indices of the output tile owned by this program.
    block_m = mpid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    block_n = npid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # These bounds checks do not depend on the reduction step, so build them
    # once and reuse them for the loads and for the final store.
    row_in_bounds = block_m[:, None] < m
    col_in_bounds = block_n[None, :] < n

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_pad in range(0, k, BLOCK_SIZE_K):
        k_range = tl.arange(0, BLOCK_SIZE_K) + k_pad
        a_ptr = a + (block_m[:, None] * stride_am + k_range[None, :] * stride_an)
        b_ptr = b + (k_range[:, None] * stride_bm + block_n[None, :] * stride_bn)
        # Out-of-range lanes are loaded as 0.0 so they add nothing to the sum.
        a_mask = row_in_bounds & (k_range[None, :] < k)
        b_mask = (k_range[:, None] < k) & col_in_bounds
        a_val = tl.load(a_ptr, mask=a_mask, other=0.0)
        b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
        acc += tl.dot(a_val, b_val)

    # ``.to`` returns a new value rather than converting in place, so the
    # result must be captured; the original discarded it.
    c_val = acc.to(tl.bfloat16)
    c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
    c_mask = row_in_bounds & col_in_bounds
    tl.store(c_ptr, c_val, mask=c_mask)
def triton_matmul(a, b):
    """Multiply two 2-D tensors with the Triton ``matmul_kernel``.

    Args:
        a: Left operand of shape (M, K).
        b: Right operand of shape (K, N), on the same device as ``a``.

    Returns:
        A newly allocated bfloat16 tensor of shape (M, N) on ``a``'s device.

    Raises:
        ValueError: If either operand is not 2-D, if the inner dimensions
            differ, or if the operands live on different devices.
    """
    if a.dim() != 2 or b.dim() != 2:
        raise ValueError(
            f"triton_matmul expects 2-D tensors, got {a.dim()}-D and {b.dim()}-D"
        )
    m, k = a.shape
    k_b, n = b.shape
    if k != k_b:
        # The kernel bounds its reads of ``b`` with ``a``'s K, so a mismatch
        # would otherwise go undetected instead of raising.
        raise ValueError(
            f"inner dimensions do not match: a is {tuple(a.shape)}, "
            f"b is {tuple(b.shape)}"
        )
    if a.device != b.device:
        raise ValueError(
            f"operands are on different devices: {a.device} and {b.device}"
        )

    # Allocate the output next to the inputs rather than on the default
    # "cuda" device, which may be a different GPU.
    c = torch.empty((m, n), device=a.device, dtype=torch.bfloat16)

    def grid(meta):
        # One program per output tile; tile extents come from the autotuner.
        return (
            triton.cdiv(m, meta["BLOCK_SIZE_M"]),
            triton.cdiv(n, meta["BLOCK_SIZE_N"]),
        )

    matmul_kernel[grid](
        a,
        b,
        c,
        m=m,
        n=n,
        k=k,
        stride_am=a.stride(0),
        stride_an=a.stride(1),
        stride_bm=b.stride(0),
        stride_bn=b.stride(1),
        stride_cm=c.stride(0),
        stride_cn=c.stride(1),
    )
    return c
def benchmark(fn, *args, **kwargs):
    """Return the mean time in milliseconds of one ``fn(*args, **kwargs)`` call.

    The callable is first invoked 10 times untimed, then 100 times between
    two CUDA timing events; the elapsed time is divided by 100.
    """
    warmup_calls, timed_calls = 10, 100

    # Warm up the GPU so first-run initialisation does not skew the timing.
    for _ in range(warmup_calls):
        fn(*args, **kwargs)

    # CUDA events used as the stopwatch.
    timer_begin = torch.cuda.Event(enable_timing=True)
    timer_end = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()
    timer_begin.record()
    # Repeat the call and average over the repetitions.
    for _ in range(timed_calls):
        fn(*args, **kwargs)
    timer_end.record()
    torch.cuda.synchronize()

    total_ms = timer_begin.elapsed_time(timer_end)
    return total_ms / timed_calls  # milliseconds per call
def benchmark_triton():
    """Time ``triton_matmul`` on 4096 x 4096 bfloat16 operands and print it."""
    # Matrix extents: a is (rows, inner), b is (inner, cols).
    rows = cols = inner = 4096
    lhs = torch.randn(rows, inner, device="cuda", dtype=torch.bfloat16)
    rhs = torch.randn(inner, cols, device="cuda", dtype=torch.bfloat16)

    # Measure and report the mean latency per call.
    time_ms = benchmark(triton_matmul, lhs, rhs)
    print(f"Triton 耗时: {time_ms:.3f} ms")
def benchmark_torch():
    """Time ``torch.matmul`` on 4096 x 4096 bfloat16 operands and print it."""
    # Same problem size as the Triton benchmark: (M, K) @ (K, N).
    dims = {"M": 4096, "N": 4096, "K": 4096}
    options = {"device": "cuda", "dtype": torch.bfloat16}
    left = torch.randn(dims["M"], dims["K"], **options)
    right = torch.randn(dims["K"], dims["N"], **options)

    # Measure and report the mean latency per call.
    time_ms = benchmark(torch.matmul, left, right)
    print(f"PyTorch 耗时: {time_ms:.3f} ms")
if __name__ == "__main__":
    # Run the Triton benchmark first, then the PyTorch reference.
    for run_benchmark in (benchmark_triton, benchmark_torch):
        run_benchmark()