import torch
import triton
import triton.language as tl


# A block pointer (tl.make_block_ptr) points to a tile of a larger tensor and can be
# used to load and store data in a block-wise manner. It is useful for block-wise
# algorithms such as matrix multiplication, where processing data in tiles improves
# memory access patterns and performance.
# Note: the accumulator dtype must be float32 (tl.dot accumulates partial sums in fp32).


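# triton.autotune benchmarks every config below the first time the kernel is launched
# with a new value of the key (M, N, K) and caches the fastest one for later launches.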
@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_M": M,
                "BLOCK_N": N,
                "BLOCK_K": K,
            },
            num_stages=s,
            num_warps=nw,
        )
        for M in [64, 128]
        for N in [64, 128]
        for K in [32, 64]
        for s in [2, 4]
        for nw in [4, 8]
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_ab,  # batch stride of A
    stride_am,
    stride_ak,  # row/column strides of A (row-major)
    stride_bb,  # batch stride of B
    stride_bk,
    stride_bn,  # row/column strides of B (row-major)
    stride_cb,  # batch stride of C
    stride_cm,
    stride_cn,  # row/column strides of C
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Determine which tile of C this program instance computes
    pid_b = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_n = tl.program_id(2)
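    # (launch grid: axis 0 = batch, axis 1 = M tiles of C, axis 2 = N tiles of C)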
    # Create block pointers for the current tiles of A and B
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr + pid_b * stride_ab,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),  # row-major
    )
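    # order=(1, 0) marks axis 1 as the fastest-varying (contiguous) axis, i.e. the
    # underlying storage is row-major.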

    b_block_ptr = tl.make_block_ptr(
        base=b_ptr + pid_b * stride_bb,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(1, 0),  # row-major
    )

    # Initialize the accumulator
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Tiled loop: step along the K dimension, one BLOCK_K slice at a time
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))  # guard out-of-bounds reads on both dims
        b = tl.load(b_block_ptr, boundary_check=(0, 1))

        # Accumulate this tile's contribution to the product
        accumulator += tl.dot(a, b)
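        # (for fp16 inputs tl.dot lowers to tensor-core MMA instructions where the
        # hardware supports them, accumulating in fp32)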

        # Advance the block pointers to the next K tile
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))

    # Create a block pointer for C and store the result
    c_block_ptr = tl.make_block_ptr(
        base=c_ptr + pid_b * stride_cb,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    accumulator = accumulator.to(tl.float16)  # cast back to float16 for the output
    tl.store(c_block_ptr, accumulator, boundary_check=(0, 1))  # mask out-of-bounds writes


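# Host-side wrapper: allocates the output, builds the launch grid, and forwards the
# raw strides so the kernel works on batched (B, M, K) @ (B, K, N) fp16 inputs.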
def matmul(a: torch.Tensor, b: torch.Tensor):
    assert a.shape[2] == b.shape[1], "inner dimensions must match"
    assert a.size(0) == b.size(0), "batch sizes must be equal"
    M, K = a.size(1), a.size(2)
    K, N = b.size(1), b.size(2)
    c = torch.zeros((a.size(0), M, N), device=a.device, dtype=a.dtype)

    # Default tile sizes (unused here: @triton.autotune selects BLOCK_M/BLOCK_N/BLOCK_K)
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32

    grid = lambda meta: (
        a.size(0),
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
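    # Triton calls `grid` with the chosen meta-parameters at launch time, so the number
    # of tiles automatically matches the BLOCK_M / BLOCK_N picked by the autotuner.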

    matmul_kernel[grid](
        a_ptr=a,
        b_ptr=b,
        c_ptr=c,
        M=M,
        N=N,
        K=K,
        stride_ab=a.stride(0),
        stride_am=a.stride(1),
        stride_ak=a.stride(2),
        stride_bb=b.stride(0),
        stride_bk=b.stride(1),
        stride_bn=b.stride(2),
        stride_cb=c.stride(0),
        stride_cm=c.stride(1),
        stride_cn=c.stride(2),
    )

    return c


# Test
if __name__ == "__main__":
    for i in range(100):
        for M in [1024, 2048, 4096]:
            for N in [1024, 2048, 4096]:
                for K in [1024, 2048, 4096]:
                    a = torch.randn(
                        size=(128, M, K), device="cuda", dtype=torch.float16
                    )
                    b = torch.randn(
                        size=(128, K, N), device="cuda", dtype=torch.float16
                    )
                    c_triton = matmul(a, b)
                    c_torch = torch.matmul(a, b)
                    print(
                        f"Results match: {torch.allclose(c_triton, c_torch, atol=1e-4)}"
                    )
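
# Optional timing sketch: assumes triton.testing.do_bench is available; it reports
# runtimes in milliseconds for a single, smaller problem size so the autotuned kernel
# can be compared against torch.matmul. The sizes below are illustrative only.
#
#     a = torch.randn(size=(8, 1024, 1024), device="cuda", dtype=torch.float16)
#     b = torch.randn(size=(8, 1024, 1024), device="cuda", dtype=torch.float16)
#     ms_triton = triton.testing.do_bench(lambda: matmul(a, b))
#     ms_torch = triton.testing.do_bench(lambda: torch.matmul(a, b))
#     print(f"triton: {ms_triton:.3f} ms  torch: {ms_torch:.3f} ms")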