A simple Triton matmul implementation. It feels basically good enough now, so it can be used to quickly verify things.

long0x0 2025-03-28 23:20:58 +08:00
parent e33d87b0aa
commit 4774d3ef39


# coding=utf-8
import triton
import triton.language as tl
import torch
configs = [
    # Config for small matrices (e.g. M, N, K < 1024)
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
        num_warps=4,   # warps per thread block (4 warps = 128 threads)
        num_stages=3,  # pipeline stages (friendlier to compute-bound kernels)
    ),
    # Config for medium matrices (e.g. 1024 <= M, N, K < 4096)
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
        num_warps=8,
        num_stages=4,
    ),
    # Config for large matrices (e.g. M, N, K >= 4096)
    triton.Config(
        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
        num_warps=16,
        num_stages=5,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
        num_warps=4,
        num_stages=3,
    ),
    # Add more parameter combinations as needed...
]
@triton.autotune(configs=configs, key=["m", "n", "k"])
@triton.jit
def matmul_kernel(
    a,
    b,
    c,
    m,
    n,
    k,
    stride_am,
    stride_an,
    stride_bm,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Each program instance computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C
    mpid = tl.program_id(0)
    npid = tl.program_id(1)
    block_m = mpid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    block_n = npid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Accumulate in float32 for numerical stability
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Walk the K dimension one BLOCK_SIZE_K slice at a time
    for k_pad in range(0, k, BLOCK_SIZE_K):
        k_range = tl.arange(0, BLOCK_SIZE_K) + k_pad
        a_ptr = a + (block_m[:, None] * stride_am + k_range[None, :] * stride_an)
        b_ptr = b + (k_range[:, None] * stride_bm + block_n[None, :] * stride_bn)
        # Mask out-of-bounds rows/columns so arbitrary m, n, k are handled
        a_mask = (block_m[:, None] < m) & (k_range[None, :] < k)
        b_mask = (k_range[:, None] < k) & (block_n[None, :] < n)
        a_val = tl.load(a_ptr, mask=a_mask, other=0.0)
        b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
        acc += tl.dot(a_val, b_val)
    # Cast the accumulator back to bf16 before writing the output tile
    acc = acc.to(tl.bfloat16)
    c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
    c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
    tl.store(c_ptr, acc, mask=c_mask)

def triton_matmul(a, b):
    # Allocate the output tensor (bf16, on the GPU)
    c = torch.empty(a.size(0), b.size(1), device="cuda", dtype=torch.bfloat16)
    # Grid size depends on the block sizes the autotuner picks
    grid = lambda meta: (
        triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
        triton.cdiv(b.size(1), meta["BLOCK_SIZE_N"]),
    )
    # Launch the kernel
    matmul_kernel[grid](
        a,
        b,
        c,
        m=a.size(0),
        n=b.size(1),
        k=a.size(1),
        stride_am=a.stride(0),
        stride_an=a.stride(1),
        stride_bm=b.stride(0),
        stride_bn=b.stride(1),
        stride_cm=c.stride(0),
        stride_cn=c.stride(1),
    )
    return c

def benchmark(fn, *args, **kwargs):
    # Warm up the GPU so first-run initialization does not skew the timing
    for _ in range(10):
        fn(*args, **kwargs)
    # Time with CUDA events
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    for _ in range(100):  # average over multiple runs
        fn(*args, **kwargs)
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / 100  # milliseconds per call


def benchmark_triton():
    # Assumes the Triton kernel matmul_kernel is launched via triton_matmul
    M, N, K = 4096, 4096, 4096  # matrix sizes
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    time_ms = benchmark(triton_matmul, a, b)
    print(f"Triton time: {time_ms:.3f} ms")


def benchmark_torch():
    # Baseline: the same shapes and dtype through torch.matmul
    M, N, K = 4096, 4096, 4096  # matrix sizes
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    time_ms = benchmark(torch.matmul, a, b)
    print(f"PyTorch time: {time_ms:.3f} ms")


if __name__ == "__main__":
    benchmark_triton()
    benchmark_torch()
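
Since the commit message's point is to quickly verify things with this kernel, a correctness check against torch.matmul fits naturally next to the benchmarks. The check_matmul helper below is a hypothetical sketch, not part of this commit: it compares the Triton result with torch.matmul on the same bf16 inputs and reports the worst-case error, which should stay on the order of bf16 rounding (roughly 1e-2 relative or smaller).

def check_matmul(M=512, N=512, K=512):
    # Hypothetical helper (not part of this commit): compare the Triton kernel
    # against torch.matmul on the same bf16 inputs.
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    c_triton = triton_matmul(a, b).float()
    c_torch = torch.matmul(a, b).float()
    max_abs = (c_triton - c_torch).abs().max().item()
    max_rel = max_abs / c_torch.abs().max().item()
    # With fp32 accumulation the mismatch comes mostly from the final bf16 cast,
    # so the relative error should be small (around 1e-2 or less).
    print(f"max abs error: {max_abs:.4f}, max rel error: {max_rel:.4e}")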