import torch
import triton
import triton.language as tl

# A block pointer (tl.make_block_ptr) describes a 2D block of memory that can be
# loaded and stored as a unit. Block pointers are useful for block-wise algorithms
# such as matrix multiplication, where tiling improves memory access patterns and
# performance. The accumulator is kept in float32 for accuracy and cast to the
# float16 output dtype before the store.


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_M": M,
                "BLOCK_N": N,
                "BLOCK_K": K,
            },
            num_stages=s,
            num_warps=nw,
        )
        for M in [64, 128]
        for N in [64, 128]
        for K in [32, 64]
        for s in [2, 4]
        for nw in [4, 8]
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_ab,  # batch stride of A
    stride_am,
    stride_ak,  # strides of A (row-major)
    stride_bb,  # batch stride of B
    stride_bk,
    stride_bn,  # strides of B (row-major)
    stride_cb,  # batch stride of C
    stride_cm,
    stride_cn,  # strides of C
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Determine which tile of C this program instance computes.
    pid_b = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_n = tl.program_id(2)

    # Block pointers for the corresponding tiles of A and B.
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr + pid_b * stride_ab,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),  # row-major
    )
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr + pid_b * stride_bb,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(1, 0),  # row-major
    )

    # Initialize the float32 accumulator.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Blocked loop: walk along the K dimension one BLOCK_K tile at a time.
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))  # mask out-of-bounds rows/cols
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        # Accumulate the partial product of this K block.
        accumulator += tl.dot(a, b)
        # Advance the block pointers to the next K block.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))

    # Block pointer for the output tile of C, then store the result.
    c_block_ptr = tl.make_block_ptr(
        base=c_ptr + pid_b * stride_cb,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    accumulator = accumulator.to(tl.float16)  # cast to the float16 output dtype
    tl.store(c_block_ptr, accumulator, boundary_check=(0, 1))


def matmul(a: torch.Tensor, b: torch.Tensor):
    assert a.size(0) == b.size(0), "batch sizes must match"
    assert a.size(2) == b.size(1), "inner dimensions must match"
    assert a.dtype == torch.float16 and b.dtype == torch.float16, "kernel stores float16"
    M, K = a.size(1), a.size(2)
    N = b.size(2)
    c = torch.zeros((a.size(0), M, N), device=a.device, dtype=a.dtype)

    # BLOCK_M / BLOCK_N / BLOCK_K are selected by the autotuner; the grid only
    # needs the batch dimension and the number of output tiles per batch.
    grid = lambda meta: (
        a.size(0),
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    matmul_kernel[grid](
        a_ptr=a,
        b_ptr=b,
        c_ptr=c,
        M=M,
        N=N,
        K=K,
        stride_ab=a.stride(0),
        stride_am=a.stride(1),
        stride_ak=a.stride(2),
        stride_bb=b.stride(0),
        stride_bk=b.stride(1),
        stride_bn=b.stride(2),
        stride_cb=c.stride(0),
        stride_cm=c.stride(1),
        stride_cn=c.stride(2),
    )
    return c


# Correctness check against torch.matmul.
if __name__ == "__main__":
    for M in [1024, 2048, 4096]:
        for N in [1024, 2048, 4096]:
            for K in [1024, 2048, 4096]:
                a = torch.randn(size=(128, M, K), device="cuda", dtype=torch.float16)
                b = torch.randn(size=(128, K, N), device="cuda", dtype=torch.float16)
                c_triton = matmul(a, b)
                c_torch = torch.matmul(a, b)
                # Loose tolerance: the outputs are float16, so an atol of 1e-4
                # would flag ordinary rounding differences as mismatches.
                print(
                    f"M={M}, N={N}, K={K}, results match: "
                    f"{torch.allclose(c_triton, c_torch, atol=1e-2, rtol=1e-2)}"
                )
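

# --- Optional benchmark sketch (not part of the original script) ---
# A minimal comparison of the Triton kernel against torch.matmul using
# triton.testing.do_bench. The batch size and shapes below are illustrative
# assumptions; adjust them to fit the available GPU memory. Call benchmark()
# manually (e.g. from the __main__ block above) if timing is wanted.
def benchmark(batch=8, M=2048, N=2048, K=2048):
    a = torch.randn((batch, M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((batch, K, N), device="cuda", dtype=torch.float16)
    # do_bench runs the callable repeatedly and returns a time in milliseconds.
    ms_triton = triton.testing.do_bench(lambda: matmul(a, b))
    ms_torch = triton.testing.do_bench(lambda: torch.matmul(a, b))
    flops = 2.0 * batch * M * N * K  # multiply-adds counted as 2 FLOPs each
    print(
        f"triton: {ms_triton:.3f} ms ({flops / ms_triton * 1e-9:.1f} TFLOP/s), "
        f"torch: {ms_torch:.3f} ms ({flops / ms_torch * 1e-9:.1f} TFLOP/s)"
    )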