This also works, and overall it is really convenient to use. What other unusual usages are there? Worth digging into further.

This commit is contained in:
longfei li 2025-04-12 13:25:22 +08:00
parent 9bc678f9a6
commit baaa5dbc1c

tests/test_block_md.py Normal file

@@ -0,0 +1,132 @@
import torch
import triton
import triton.language as tl
# A block_ptr is a pointer to a block (tile) of memory that can be used to load and store data block-wise.
# It is useful for implementing block-wise algorithms such as matrix multiplication, where data is
# processed in tiles to improve memory access patterns and performance.
# Note: the accumulator below must be float32 (tl.dot accumulates in fp32); the inputs here are float16.
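
# As a quick illustration of the API described above (a sketch, independent of the matmul kernel below),
# the same make_block_ptr -> load -> store pattern can be used for a plain 2D tile copy. The kernel
# name and parameters here are illustrative only.
@triton.jit
def copy_kernel(
    src_ptr,
    dst_ptr,
    M,
    N,
    stride_m,
    stride_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointer into the source tile
    src_block_ptr = tl.make_block_ptr(
        base=src_ptr,
        shape=(M, N),
        strides=(stride_m, stride_n),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    # Block pointer into the destination tile (same layout)
    dst_block_ptr = tl.make_block_ptr(
        base=dst_ptr,
        shape=(M, N),
        strides=(stride_m, stride_n),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    # Load one tile (zero-padded at the boundary) and write it back out
    tile = tl.load(src_block_ptr, boundary_check=(0, 1))
    tl.store(dst_block_ptr, tile, boundary_check=(0, 1))
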
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    BS,
    M,
    N,
    K,
    stride_ab,  # batch stride of A
    stride_am,
    stride_ak,  # strides of A (row-major)
    stride_bb,
    stride_bk,
    stride_bn,  # strides of B (row-major)
    stride_cb,
    stride_cm,
    stride_cn,  # strides of C
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Determine which tile of C this program instance computes
    pid_b = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_n = tl.program_id(2)
    # Create block pointers for the corresponding tiles of A and B
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr + pid_b * stride_ab,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),  # row-major
    )
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr + pid_b * stride_bb,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(1, 0),  # row-major
    )
    # Initialize the accumulator
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Loop over the K dimension one block at a time
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))  # guard against out-of-bounds reads
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        # Accumulate the partial product of this K block
        accumulator += tl.dot(a, b)
        # Advance the block pointers to the next K block
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))
    # Create the block pointer for C and store the result
    c_block_ptr = tl.make_block_ptr(
        base=c_ptr + pid_b * stride_cb,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    accumulator = accumulator.to(tl.float16)  # cast back to float16 for the output
    tl.store(c_block_ptr, accumulator, boundary_check=(0, 1))


def matmul(a: torch.Tensor, b: torch.Tensor):
    assert a.shape[2] == b.shape[1], "inner dimensions do not match"
    assert a.size(0) == b.size(0), "batch sizes must be equal"
    M, K = a.size(1), a.size(2)
    N = b.size(2)
    c = torch.zeros((a.size(0), M, N), device=a.device, dtype=a.dtype)
    # Block sizes for each dimension (tunable)
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    grid = lambda meta: (
        a.size(0),
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    matmul_kernel[grid](
        a_ptr=a,
        b_ptr=b,
        c_ptr=c,
        BS=a.size(0),
        M=M,
        N=N,
        K=K,
        stride_ab=a.stride(0),
        stride_am=a.stride(1),
        stride_ak=a.stride(2),
        stride_bb=b.stride(0),
        stride_bk=b.stride(1),
        stride_bn=b.stride(2),
        stride_cb=c.stride(0),
        stride_cm=c.stride(1),
        stride_cn=c.stride(2),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )
    return c


# Test
if __name__ == "__main__":
    for i in range(100):
        a = torch.randn(size=(128, 512, 256), device="cuda", dtype=torch.float16)
        b = torch.randn(size=(128, 256, 384), device="cuda", dtype=torch.float16)
        c_triton = matmul(a, b)
        c_torch = torch.matmul(a, b)
        print(f"results match: {torch.allclose(c_triton, c_torch, atol=1e-2)}")