124 lines
3.4 KiB
Python
124 lines
3.4 KiB
Python
import torch
|
||
import triton
|
||
import triton.language as tl
|
||
|
||
|
||
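# A block pointer (`tl.make_block_ptr`) describes a 2-D tile of a larger tensor:
# its base address, full shape, strides, the tile's starting offsets and the tile
# shape. Whole tiles can then be loaded and stored at once (with optional
# boundary checks) and moved along the tensor with `tl.advance`.
# This suits block-wise algorithms such as matrix multiplication, where data is
# processed tile by tile to improve memory-access patterns and performance.
# The kernel below accumulates partial products in float32, independent of the
# input dtype.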
|
||
|
||
|
||
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,  # strides of A (row-major layout expected)
    stride_bk,
    stride_bn,  # strides of B (row-major layout expected)
    stride_cm,
    stride_cn,  # strides of C
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute one (BLOCK_M, BLOCK_N) tile of C = A @ B using block pointers.

    A is (M, K), B is (K, N) and C is (M, N); each matrix is addressed through
    its own pair of strides. The program with ids (pid_m, pid_n) on grid axes
    0 and 1 owns the C tile whose top-left element is
    (pid_m * BLOCK_M, pid_n * BLOCK_N).

    Partial products are accumulated in float32 and the finished tile is cast
    to C's own element type before it is stored. Out-of-range elements of the
    edge tiles are zero-padded on load and skipped on store, so M, N and K do
    not have to be multiples of the block sizes.
    """
    # Which tile of C this program computes.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # (BLOCK_M, BLOCK_K) window over A: this program's row band, starting at
    # column 0 and walking right along K.
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),  # row-major
    )

    # (BLOCK_K, BLOCK_N) window over B: this program's column band, starting
    # at row 0 and walking down along K.
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(1, 0),  # row-major
    )

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # March both windows along the K dimension, one BLOCK_K slab at a time.
    for _ in range(0, K, BLOCK_K):
        # Zero padding is required: without `padding_option` the out-of-range
        # elements of a partial tile are undefined, and on the last K slab
        # (when K % BLOCK_K != 0) they would be summed into the dot product.
        a = tl.load(a_block_ptr, boundary_check=(0, 1), padding_option="zero")
        b = tl.load(b_block_ptr, boundary_check=(0, 1), padding_option="zero")

        accumulator += tl.dot(a, b)

        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))

    # Window over this program's output tile of C.
    c_block_ptr = tl.make_block_ptr(
        base=c_ptr,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    # Cast to C's own element type instead of a hard-coded float16, so the
    # kernel also works when C is allocated as float32/bfloat16. The boundary
    # check keeps edge tiles from writing past row M / column N.
    tl.store(
        c_block_ptr,
        accumulator.to(c_ptr.dtype.element_ty),
        boundary_check=(0, 1),
    )
|
||
|
||
|
||
def matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    *,
    block_m: int = 64,
    block_n: int = 64,
    block_k: int = 32,
) -> torch.Tensor:
    """Compute ``a @ b`` with the Triton ``matmul_kernel``.

    Args:
        a: Left operand of shape (M, K).
        b: Right operand of shape (K, N). Presumably it must live on the same
            CUDA device as ``a`` (the kernel receives both as raw pointers).
        block_m: Tile height along M, passed as the kernel's ``BLOCK_M``.
        block_n: Tile width along N, passed as the kernel's ``BLOCK_N``.
        block_k: Tile depth along K, passed as the kernel's ``BLOCK_K``.

    Returns:
        A new (M, N) tensor on ``a``'s device with ``a``'s dtype.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    assert a.shape[1] == b.shape[0], "维度不匹配"
    M, K = a.shape
    _, N = b.shape

    # The launch grid covers every tile of C and the kernel writes every
    # in-range element, so the output does not need to be zero-initialised.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # One program per (block_m, block_n) tile of C.
    grid = (triton.cdiv(M, block_m), triton.cdiv(N, block_n))

    matmul_kernel[grid](
        a_ptr=a,
        b_ptr=b,
        c_ptr=c,
        M=M,
        N=N,
        K=K,
        stride_am=a.stride(0),
        stride_ak=a.stride(1),
        stride_bk=b.stride(0),
        stride_bn=b.stride(1),
        stride_cm=c.stride(0),
        stride_cn=c.stride(1),
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        BLOCK_K=block_k,
    )

    return c
|
||
|
||
|
||
# Smoke test: run the Triton kernel and torch.matmul on the same seeded random
# float16 operands, print both results and report whether they agree.
if __name__ == "__main__":
    torch.manual_seed(0)

    fp16_on_gpu = dict(device="cuda", dtype=torch.float16)
    a = torch.randn(512, 256, **fp16_on_gpu)
    b = torch.randn(256, 384, **fp16_on_gpu)

    c_triton = matmul(a, b)
    c_torch = torch.matmul(a, b)
    print(c_triton, c_torch)

    matches = torch.allclose(c_triton, c_torch, atol=1e-2)
    print(f"结果是否一致: {matches}")
|