import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(
    A, B, C,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C = A @ B.

    A is (M, K), B is (K, N), C is (M, N). All three are assumed
    contiguous row-major (pointer arithmetic below derives strides
    from K and N rather than taking explicit stride arguments).
    One program instance per output tile, indexed by a 2-D grid.
    """
    # Which output tile this program is responsible for.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Element offsets covered by this tile (tl.arange, not Python range:
    # Triton needs compile-time-shaped tensors of indices).
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Accumulate in fp32 regardless of the input dtype for accuracy.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # March along the K dimension in BLOCK_SIZE_K-wide steps; the masks
    # zero-fill the ragged edges when M, N or K is not a multiple of the
    # corresponding block size.
    for k in range(0, K, BLOCK_SIZE_K):
        k_offs = k + offs_k
        a = tl.load(
            A + offs_m[:, None] * K + k_offs[None, :],
            mask=(offs_m[:, None] < M) & (k_offs[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + k_offs[:, None] * N + offs_n[None, :],
            mask=(k_offs[:, None] < K) & (offs_n[None, :] < N),
            other=0.0,
        )
        acc += tl.dot(a, b)

    # Write the finished tile, masking rows/columns past the matrix edge.
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(C + offs_m[:, None] * N + offs_n[None, :], acc, mask=c_mask)


def matmul(a, b):
    """Multiply two 2-D CUDA tensors using ``matmul_kernel``.

    Args:
        a: tensor of shape (M, K), contiguous, on a CUDA device.
        b: tensor of shape (K, N), contiguous, on the same device.

    Returns:
        Tensor of shape (M, N) with ``a.dtype``, equal to ``a @ b``.

    Raises:
        ValueError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    M, K = a.shape
    K2, N = b.shape
    # The original silently shadowed K; check the contraction dims agree.
    if K != K2:
        raise ValueError(f"incompatible shapes for matmul: {tuple(a.shape)} @ {tuple(b.shape)}")

    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # tl.dot requires tile dimensions of at least 16.
    BLOCK_SIZE_M = 16
    BLOCK_SIZE_N = 16
    BLOCK_SIZE_K = 16

    # One program per output tile; cdiv rounds up so edge tiles exist
    # (the kernel masks their out-of-range elements).
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
    matmul_kernel[grid](a, b, c, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)
    return c


# Example usage — guarded so importing this module does not launch CUDA work.
if __name__ == "__main__":
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    c = matmul(a, b)
    print(c)