# Triton block-tiled matrix multiplication (C = A @ B) with a PyTorch host wrapper.
import triton
|
||
|
|
import triton.language as tl
|
||
|
|
import torch
|
||
|
|
|
||
|
|
|
||
|
|
@triton.jit
def matmul_kernel(
    A,
    B,
    C,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C = A @ B.

    A is (M, K), B is (K, N), C is (M, N); all three are assumed
    row-major and contiguous. Each program instance (pid_m, pid_n)
    owns exactly one output tile of C.
    """
    # Index of this program's tile in the 2-D launch grid.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Row/column offsets of this tile. tl.arange (not Python range/lambda)
    # is required inside a @triton.jit kernel to form vector indices.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Accumulate in float32 regardless of the input dtype.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # March along K one BLOCK_SIZE_K-wide slab at a time so tl.dot sees
    # proper (BLOCK_M, BLOCK_K) x (BLOCK_K, BLOCK_N) operands.
    for k0 in range(0, K, BLOCK_SIZE_K):
        k = k0 + offs_k
        # Masks allow M, N, K that are not multiples of the block sizes:
        # out-of-range lanes load 0.0 and contribute nothing to the dot.
        a = tl.load(
            A + offs_m[:, None] * K + k[None, :],
            mask=(offs_m[:, None] < M) & (k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + k[:, None] * N + offs_n[None, :],
            mask=(k[:, None] < K) & (offs_n[None, :] < N),
            other=0.0,
        )
        acc += tl.dot(a, b)

    # Write the finished tile back, masking the ragged right/bottom edges.
    tl.store(
        C + offs_m[:, None] * N + offs_n[None, :],
        acc,
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )
|
||
|
|
|
||
|
|
|
||
|
|
def matmul(a, b):
    """Multiply two 2-D CUDA tensors using the Triton kernel.

    Args:
        a: Tensor of shape (M, K) on a CUDA device.
        b: Tensor of shape (K, N) on the same device.

    Returns:
        Tensor c of shape (M, N) with a's dtype and device, c = a @ b.

    Raises:
        ValueError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    # Unpack dimensions; keep b's K separate so a mismatch is detectable
    # instead of being silently shadowed.
    M, K = a.shape
    K_b, N = b.shape
    if K != K_b:
        # Fail fast rather than launching the kernel over garbage memory.
        raise ValueError(f"inner dimensions must match: got {K} and {K_b}")

    # Output buffer; the kernel fills every element, so empty() is safe.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # Tile sizes for the kernel (16 is the minimum tl.dot operand size).
    BLOCK_SIZE_M = 16
    BLOCK_SIZE_N = 16
    BLOCK_SIZE_K = 16

    # One program per output tile; cdiv rounds up for ragged edges.
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))

    matmul_kernel[grid](a, b, c, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)

    return c
|
||
|
|
|
||
|
|
|
||
|
|
# Example usage: multiply two 1024x1024 random matrices on the GPU.
# Guarded so importing this module does not allocate CUDA memory or
# launch a kernel as a side effect.
if __name__ == "__main__":
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    c = matmul(a, b)
    print(c)
|