diff --git a/test_triton.py b/test_triton.py
index e2d985b..767d803 100644
--- a/test_triton.py
+++ b/test_triton.py
@@ -1,67 +1,146 @@
+# coding=utf-8
 import triton
 import triton.language as tl
 import torch
 
+configs = [
+    # Config for small matrices (e.g. M, N, K < 1024)
+    triton.Config(
+        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
+        num_warps=4,  # warps per program (4 warps = 128 threads)
+        num_stages=3,  # software pipeline stages (friendlier to compute-bound ops)
+    ),
+    # For medium matrices (e.g. 1024 ≤ M, N, K < 4096)
+    triton.Config(
+        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
+        num_warps=8,
+        num_stages=4,
+    ),
+    # For large matrices (e.g. M, N, K ≥ 4096)
+    triton.Config(
+        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
+        num_warps=16,
+        num_stages=5,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
+        num_warps=4,
+        num_stages=3,
+    ),
+    # Add more parameter combinations...
+]
+
+
+@triton.autotune(configs=configs, key=["m", "n", "k"])
 @triton.jit
 def matmul_kernel(
-    A,
-    B,
-    C,
-    M: tl.constexpr,
-    N: tl.constexpr,
-    K: tl.constexpr,
+    a,
+    b,
+    c,
+    m,
+    n,
+    k,
+    stride_am,
+    stride_an,
+    stride_bm,
+    stride_bn,
+    stride_cm,
+    stride_cn,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
 ):
-    # Get the index of the current block
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
-
-    # Compute the start and end indices of the current block
-    range_m = lambda: range(pid_m * BLOCK_SIZE_M, (pid_m + 1) * BLOCK_SIZE_M)
-    range_n = lambda: range(pid_n * BLOCK_SIZE_N, (pid_n + 1) * BLOCK_SIZE_N)
-    range_k = lambda: range(0, K)
-
-    # Initialize the result block
+    # Row/column offsets of the output tile owned by this program instance
+    mpid = tl.program_id(0)
+    npid = tl.program_id(1)
+    block_m = mpid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    block_n = npid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
-    # Compute the matrix multiplication
-    for k in range_k():
-        a = tl.load(A + range_m()[:, None] * K + k)
-        b = tl.load(B + k * N + range_n()[None, :])
-        acc += tl.dot(a, b)
+    # Accumulate over the K dimension one BLOCK_SIZE_K tile at a time
+    for k_pad in range(0, k, BLOCK_SIZE_K):
+        k_range = tl.arange(0, BLOCK_SIZE_K) + k_pad
+        a_ptr = a + (block_m[:, None] * stride_am + k_range[None, :] * stride_an)
+        b_ptr = b + (k_range[:, None] * stride_bm + block_n[None, :] * stride_bn)
+        a_mask = (block_m[:, None] < m) & (k_range[None, :] < k)
+        b_mask = (k_range[:, None] < k) & (block_n[None, :] < n)
 
-    # Write the result back to the output matrix
-    tl.store(C + range_m()[:, None] * N + range_n()[None, :], acc)
+        a_val = tl.load(a_ptr, mask=a_mask, other=0.0)
+        b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
+        acc += tl.dot(a_val, b_val)
 
+    # Downcast the fp32 accumulator and write the output tile back
+    acc = acc.to(tl.bfloat16)
+    c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
+    c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
+    tl.store(c_ptr, acc, mask=c_mask)
 
-def matmul(a, b):
-    # Get the matrix dimensions
-    M, K = a.shape
-    K, N = b.shape
+
+def triton_matmul(a, b):
+    # Allocate the output tensor
+    c = torch.empty(a.size(0), b.size(1), device="cuda", dtype=torch.bfloat16)
 
-    # Create the output matrix
-    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
-
-    # Define the block sizes
-    BLOCK_SIZE_M = 16
-    BLOCK_SIZE_N = 16
-    BLOCK_SIZE_K = 16
-
-    # Compute the grid size
-    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
-
-    # Launch the kernel
-    matmul_kernel[grid](a, b, c, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)
+    # Define the grid size (block sizes come from the autotuner's chosen config)
+    grid = lambda meta: (
+        triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
+        triton.cdiv(b.size(1), meta["BLOCK_SIZE_N"]),
+    )
+    # Launch the kernel
+    matmul_kernel[grid](
+        a,
+        b,
+        c,
+        m=a.size(0),
+        n=b.size(1),
+        k=a.size(1),
+        stride_am=a.stride(0),
+        stride_an=a.stride(1),
+        stride_bm=b.stride(0),
+        stride_bn=b.stride(1),
+        stride_cm=c.stride(0),
+        stride_cn=c.stride(1),
+    )
     return c
 
 
-# Example usage
+def benchmark(fn, *args, **kwargs):
+    # Warm up the GPU (keep first-run initialization out of the timing)
+    for _ in range(10):
+        fn(*args, **kwargs)
 
-a = torch.randn(1024, 1024, device="cuda")
-b = torch.randn(1024, 1024, device="cuda")
-c = matmul(a, b)
-print(c)
+    # Create CUDA events for timing
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
+    torch.cuda.synchronize()
+    start_event.record()
+    for _ in range(100):  # average over multiple runs
+        fn(*args, **kwargs)
+    end_event.record()
+    torch.cuda.synchronize()
+
+    return start_event.elapsed_time(end_event) / 100  # milliseconds per call
+
+
+def benchmark_triton():
+    # The Triton kernel is matmul_kernel, invoked through the triton_matmul wrapper
+    M, N, K = 4096, 4096, 4096  # matrix dimensions
+    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
+
+    # Measure
+    time_ms = benchmark(triton_matmul, a, b)
+    print(f"Triton time: {time_ms:.3f} ms")
+
+
+def benchmark_torch():
+    # Baseline: PyTorch's built-in torch.matmul on the same problem size
+    M, N, K = 4096, 4096, 4096  # matrix dimensions
+    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
+
+    # Measure
+    time_ms = benchmark(torch.matmul, a, b)
+    print(f"PyTorch time: {time_ms:.3f} ms")
+
+
+if __name__ == "__main__":
+    benchmark_triton()
+    benchmark_torch()