# coding=utf-8
import triton
import triton.language as tl
import torch

configs = [
    # Config for small matrices (e.g. M, N, K < 1024)
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
        num_warps=4,   # number of warps per thread block (4 warps = 128 threads)
        num_stages=3,  # number of pipeline stages (more stages favor compute-bound kernels)
    ),
    # Config for medium matrices (e.g. 1024 ≤ M, N, K < 4096)
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
        num_warps=8,
        num_stages=4,
    ),
    # Config for large matrices (e.g. M, N, K ≥ 4096)
    triton.Config(
        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
        num_warps=16,
        num_stages=5,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
        num_warps=4,
        num_stages=3,
    ),
    # Add more parameter combinations here...
]


@triton.autotune(configs=configs, key=["m", "n", "k"])
@triton.jit
def matmul_kernel(
    a, b, c,
    m, n, k,
    stride_am, stride_an,
    stride_bm, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C.
    mpid = tl.program_id(0)
    npid = tl.program_id(1)
    block_m = mpid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    block_n = npid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Accumulate in float32 for better numerical accuracy.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_pad in range(0, k, BLOCK_SIZE_K):
        k_range = tl.arange(0, BLOCK_SIZE_K) + k_pad
        a_ptr = a + (block_m[:, None] * stride_am + k_range[None, :] * stride_an)
        b_ptr = b + (k_range[:, None] * stride_bm + block_n[None, :] * stride_bn)
        # Mask out-of-bounds rows/columns so partial tiles read zeros.
        a_mask = (block_m[:, None] < m) & (k_range[None, :] < k)
        b_mask = (k_range[:, None] < k) & (block_n[None, :] < n)
        a_val = tl.load(a_ptr, mask=a_mask, other=0.0)
        b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
        acc += tl.dot(a_val, b_val)
    # Cast the accumulator down to bfloat16 before storing (note: .to() returns a new value).
    acc = acc.to(tl.bfloat16)
    c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
    c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
    tl.store(c_ptr, acc, mask=c_mask)


def triton_matmul(a, b):
    # Allocate the output tensor
    c = torch.empty(a.size(0), b.size(1), device="cuda", dtype=torch.bfloat16)
    # Define the launch grid: one program per output tile
    grid = lambda meta: (
        triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
        triton.cdiv(b.size(1), meta["BLOCK_SIZE_N"]),
    )
    # Launch the kernel
    matmul_kernel[grid](
        a, b, c,
        m=a.size(0), n=b.size(1), k=a.size(1),
        stride_am=a.stride(0), stride_an=a.stride(1),
        stride_bm=b.stride(0), stride_bn=b.stride(1),
        stride_cm=c.stride(0), stride_cn=c.stride(1),
    )
    return c


def benchmark(fn, *args, **kwargs):
    # Warm up the GPU (so first-run initialization does not skew the timing)
    for _ in range(10):
        fn(*args, **kwargs)
    # Time with CUDA events
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    for _ in range(100):  # run many times and average
        fn(*args, **kwargs)
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / 100  # milliseconds per call


def benchmark_triton():
    # Assumes the Triton kernel is named matmul_kernel and is called via triton_matmul
    M, N, K = 4096, 4096, 4096  # matrix dimensions
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    # Measure
    time_ms = benchmark(triton_matmul, a, b)
    print(f"Triton time: {time_ms:.3f} ms")


def benchmark_torch():
    # Uses torch.matmul as the PyTorch baseline
    M, N, K = 4096, 4096, 4096  # matrix dimensions
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    # Measure
    time_ms = benchmark(torch.matmul, a, b)
    print(f"PyTorch time: {time_ms:.3f} ms")


if __name__ == "__main__":
    benchmark_triton()
    benchmark_torch()
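

# Illustrative addition, not part of the original script: a minimal correctness-check
# sketch that compares triton_matmul against torch.matmul before trusting the timings.
# The function name `check_correctness` and the tolerances below are assumptions chosen
# for bfloat16 rounding error, not values taken from the original code. It could be
# called manually, e.g. before benchmark_triton() in the __main__ block above.
def check_correctness(M=512, N=512, K=512, atol=1e-2, rtol=2e-2):
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    out_triton = triton_matmul(a, b)
    out_torch = torch.matmul(a, b)
    # Compare in float32; bfloat16 has ~8 mantissa bits, so the tolerance is loose.
    if torch.allclose(out_triton.float(), out_torch.float(), atol=atol, rtol=rtol):
        print("Correctness check passed")
    else:
        max_err = (out_triton.float() - out_torch.float()).abs().max().item()
        print(f"Correctness check failed, max abs error = {max_err:.4f}")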