也是可以的,总体用起来还真是方便,还有什么稀奇的用法呢,可以继续研究一下。
This commit is contained in:
parent
9bc678f9a6
commit
baaa5dbc1c
132
tests/test_block_md.py
Normal file
132
tests/test_block_md.py
Normal file
@ -0,0 +1,132 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# so, the block_ptr is a pointer to a block of memory, which can be used to load and store data in a block-wise manner.
|
||||
# It is useful for implementing block-wise algorithms, such as matrix multiplication, where data is processed in blocks to improve memory access patterns and performance.
|
||||
# the dtype must be float32
|
||||
|
||||
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
BS,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_ab, # stride batch size
|
||||
stride_am,
|
||||
stride_ak, # A的strides (行优先)
|
||||
stride_bb,
|
||||
stride_bk,
|
||||
stride_bn, # B的strides (行优先)
|
||||
stride_cb,
|
||||
stride_cm,
|
||||
stride_cn, # C的strides
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
):
|
||||
# 确定当前线程块处理C的哪个块
|
||||
pid_b = tl.program_id(0)
|
||||
pid_m = tl.program_id(1)
|
||||
pid_n = tl.program_id(2)
|
||||
# 创建块指针加载A和B的对应块
|
||||
a_block_ptr = tl.make_block_ptr(
|
||||
base=a_ptr + pid_b * stride_ab,
|
||||
shape=(M, K),
|
||||
strides=(stride_am, stride_ak),
|
||||
offsets=(pid_m * BLOCK_M, 0),
|
||||
block_shape=[BLOCK_M, BLOCK_K],
|
||||
order=(1, 0), # 行优先
|
||||
)
|
||||
|
||||
b_block_ptr = tl.make_block_ptr(
|
||||
base=b_ptr + pid_b * stride_bb,
|
||||
shape=(K, N),
|
||||
strides=(stride_bk, stride_bn),
|
||||
offsets=(0, pid_n * BLOCK_N),
|
||||
block_shape=[BLOCK_K, BLOCK_N],
|
||||
order=(1, 0), # 行优先
|
||||
)
|
||||
|
||||
# 初始化累加器
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
|
||||
# 分块循环:沿K维度逐步计算
|
||||
for k in range(0, K, BLOCK_K):
|
||||
a = tl.load(a_block_ptr, boundary_check=(0, 1)) # 检查K维度边界
|
||||
b = tl.load(b_block_ptr, boundary_check=(0, 1))
|
||||
|
||||
# 计算矩阵乘法的分块累加
|
||||
accumulator += tl.dot(a, b)
|
||||
|
||||
# 移动块指针到下一个K块
|
||||
a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
|
||||
b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))
|
||||
|
||||
# 创建C的块指针并存储结果
|
||||
c_block_ptr = tl.make_block_ptr(
|
||||
base=c_ptr + pid_b * stride_cb,
|
||||
shape=(M, N),
|
||||
strides=(stride_cm, stride_cn),
|
||||
offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0),
|
||||
)
|
||||
accumulator = accumulator.to(tl.float16) # 转换为float16
|
||||
tl.store(c_block_ptr, accumulator)
|
||||
|
||||
|
||||
def matmul(a: torch.Tensor, b: torch.Tensor):
|
||||
assert a.shape[2] == b.shape[1], "维度不匹配"
|
||||
assert a.size(0) == b.size(0), "bs must be equal"
|
||||
M, K = a.size(1), a.size(2)
|
||||
K, N = b.size(1), b.size(2)
|
||||
c = torch.zeros((a.size(0), M, N), device=a.device, dtype=a.dtype)
|
||||
|
||||
# 定义每个维度的块大小
|
||||
BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32 # 可调整的块大小
|
||||
|
||||
grid = lambda meta: (
|
||||
a.size(0),
|
||||
triton.cdiv(M, meta["BLOCK_M"]),
|
||||
triton.cdiv(N, meta["BLOCK_N"]),
|
||||
)
|
||||
|
||||
matmul_kernel[grid](
|
||||
a_ptr=a,
|
||||
b_ptr=b,
|
||||
c_ptr=c,
|
||||
BS=a.size(0),
|
||||
M=M,
|
||||
N=N,
|
||||
K=K,
|
||||
stride_ab=a.stride(0),
|
||||
stride_am=a.stride(1),
|
||||
stride_ak=a.stride(2),
|
||||
stride_bb=b.stride(0),
|
||||
stride_bk=b.stride(1),
|
||||
stride_bn=b.stride(2),
|
||||
stride_cb=c.stride(0),
|
||||
stride_cm=c.stride(1),
|
||||
stride_cn=c.stride(2),
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_K=BLOCK_K,
|
||||
)
|
||||
|
||||
return c
|
||||
|
||||
|
||||
# 测试
|
||||
if __name__ == "__main__":
|
||||
for i in range(100):
|
||||
a = torch.randn(size=(128, 512, 256), device="cuda", dtype=torch.float16)
|
||||
b = torch.randn(size=(128, 256, 384), device="cuda", dtype=torch.float16)
|
||||
c_triton = matmul(a, b)
|
||||
c_torch = torch.matmul(a, b)
|
||||
print(f"结果是否一致: {torch.allclose(c_triton, c_torch, atol=1e-2)}")
|
||||
Loading…
Reference in New Issue
Block a user