torch_ext/tests/test_block_ptr.py

import torch
import triton
import triton.language as tl
# A block_ptr is a pointer to a rectangular block (tile) of a tensor. It lets a
# kernel load and store data tile by tile, which is useful for blocked
# algorithms such as matrix multiplication, where processing in tiles improves
# memory access patterns and performance.
# Note: tl.dot accumulates in float32, so the accumulator below is float32
# even though the inputs and output here are float16.
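
# A minimal standalone sketch of the block_ptr API in isolation (illustrative
# only; this kernel and the name copy_tile_kernel are our addition, not part
# of the matmul under test): copy one BLOCK_M x BLOCK_N tile of a row-major
# matrix from src to dst.
@triton.jit
def copy_tile_kernel(
    src_ptr, dst_ptr, M, N, stride_m, stride_n,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    src = tl.make_block_ptr(
        base=src_ptr, shape=(M, N), strides=(stride_m, stride_n),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N), order=(1, 0),
    )
    dst = tl.make_block_ptr(
        base=dst_ptr, shape=(M, N), strides=(stride_m, stride_n),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N), order=(1, 0),
    )
    # boundary_check masks rows/columns that fall outside the matrix edges
    tile = tl.load(src, boundary_check=(0, 1))
    tl.store(dst, tile, boundary_check=(0, 1))
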
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,  # strides of A (row-major)
    stride_bk,
    stride_bn,  # strides of B (row-major)
    stride_cm,
    stride_cn,  # strides of C
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Determine which tile of C this program instance computes
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Create block pointers for the corresponding tiles of A and B
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),  # row-major
    )
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(1, 0),  # row-major
    )
    # Initialize the accumulator (float32, as required by tl.dot)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Blocked loop: march along the K dimension one BLOCK_K slab at a time
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))  # mask out-of-range elements in both dims
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        # Accumulate the partial product of this tile pair
        accumulator += tl.dot(a, b)
        # Advance the block pointers to the next K slab
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))
    # Create a block pointer for the C tile and store the result
    c_block_ptr = tl.make_block_ptr(
        base=c_ptr,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    accumulator = accumulator.to(tl.float16)  # cast back to float16 for the output
    # boundary_check on the store prevents out-of-bounds writes when M or N
    # is not a multiple of the block size
    tl.store(c_block_ptr, accumulator, boundary_check=(0, 1))


def matmul(a: torch.Tensor, b: torch.Tensor):
    assert a.shape[1] == b.shape[0], "inner dimensions do not match"
    M, K = a.shape
    K, N = b.shape
    c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
    # Tile size for each dimension (tunable)
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    matmul_kernel[grid](
        a_ptr=a,
        b_ptr=b,
        c_ptr=c,
        M=M,
        N=N,
        K=K,
        stride_am=a.stride(0),
        stride_ak=a.stride(1),
        stride_bk=b.stride(0),
        stride_bn=b.stride(1),
        stride_cm=c.stride(0),
        stride_cn=c.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )
    return c
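

# A benchmarking sketch (the helper name bench_matmul and its defaults are our
# addition). triton.testing.do_bench is Triton's built-in timer; it runs the
# callable repeatedly and returns the measured runtime in milliseconds.
def bench_matmul(M=512, N=384, K=256, dtype=torch.float16):
    a = torch.randn(M, K, device="cuda", dtype=dtype)
    b = torch.randn(K, N, device="cuda", dtype=dtype)
    ms_triton = triton.testing.do_bench(lambda: matmul(a, b))
    ms_torch = triton.testing.do_bench(lambda: torch.matmul(a, b))
    print(f"triton: {ms_triton:.3f} ms  torch: {ms_torch:.3f} ms")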


# Quick correctness check against torch.matmul
if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn(512, 256, device="cuda", dtype=torch.float16)
    b = torch.randn(256, 384, device="cuda", dtype=torch.float16)
    c_triton = matmul(a, b)
    c_torch = torch.matmul(a, b)
    print(c_triton, c_torch)
    print(f"results match: {torch.allclose(c_triton, c_torch, atol=1e-2)}")