torch_ext/test_triton.py

import triton
import triton.language as tl
import torch


@triton.jit
def matmul_kernel(
    A,
    B,
    C,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Index of the output tile handled by this program instance
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Row / column / depth offsets covered by this tile
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Accumulator for the result tile
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Walk over K in BLOCK_SIZE_K-sized steps, accumulating partial products
    # (assumes M, N, K are multiples of the block sizes, so no masking is needed)
    for k in range(0, K, BLOCK_SIZE_K):
        a = tl.load(A + offs_m[:, None] * K + (k + offs_k)[None, :])
        b = tl.load(B + (k + offs_k)[:, None] * N + offs_n[None, :])
        acc += tl.dot(a, b)
    # Write the result tile back to the output matrix
    tl.store(C + offs_m[:, None] * N + offs_n[None, :], acc)


def matmul(a, b):
    # Matrix dimensions: a is (M, K), b is (K, N)
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions of a and b must match"
    # Allocate the output matrix
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # Tile sizes (16 is the smallest size tl.dot accepts per dimension)
    BLOCK_SIZE_M = 16
    BLOCK_SIZE_N = 16
    BLOCK_SIZE_K = 16
    # One program instance per output tile
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
    # Launch the kernel
    matmul_kernel[grid](a, b, c, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)
    return c


# Example usage
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
c = matmul(a, b)
print(c)
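
# Sanity check (a minimal sketch, assuming float32 inputs on a CUDA device):
# compare the Triton result against torch.matmul within a loose tolerance.
expected = torch.matmul(a, b)
print(torch.allclose(c, expected, atol=1e-3, rtol=1e-3))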