The block stuff feels mostly done. Next, implement a multi-dimensional one.

longfei li 2025-04-12 13:11:54 +08:00
parent 4be98aed30
commit 9bc678f9a6
2 changed files with 133 additions and 0 deletions

tests/test_block_ptr.py Normal file

@@ -0,0 +1,123 @@
import torch
import triton
import triton.language as tl
# so, the block_ptr is a pointer to a block of memory, which can be used to load and store data in a block-wise manner.
# It is useful for implementing block-wise algorithms, such as matrix multiplication, where data is processed in blocks to improve memory access patterns and performance.
# Note: tl.dot accumulates in float32, so the accumulator must be float32 even when the inputs are float16.
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,  # strides of A (row-major)
    stride_bk,
    stride_bn,  # strides of B (row-major)
    stride_cm,
    stride_cn,  # strides of C
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Determine which tile of C this program instance computes
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Create block pointers for the corresponding tiles of A and B
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),  # row-major
    )
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(1, 0),  # row-major
    )
    # Initialize the accumulator
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Loop over the K dimension one block at a time
    for k in range(0, K, BLOCK_K):
        # Check bounds on both dimensions and pad out-of-bounds elements with zeros
        a = tl.load(a_block_ptr, boundary_check=(0, 1), padding_option="zero")
        b = tl.load(b_block_ptr, boundary_check=(0, 1), padding_option="zero")
        # Accumulate the partial product for this K block
        accumulator += tl.dot(a, b)
        # Advance the block pointers to the next K block
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))
    # Create a block pointer for C and store the result
    c_block_ptr = tl.make_block_ptr(
        base=c_ptr,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    accumulator = accumulator.to(tl.float16)  # cast back to float16 to match C's dtype
    tl.store(c_block_ptr, accumulator, boundary_check=(0, 1))
def matmul(a: torch.Tensor, b: torch.Tensor):
    assert a.shape[1] == b.shape[0], "dimension mismatch"
    M, K = a.shape
    K, N = b.shape
    c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
    # Block sizes for each dimension (tunable)
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    matmul_kernel[grid](
        a_ptr=a,
        b_ptr=b,
        c_ptr=c,
        M=M,
        N=N,
        K=K,
        stride_am=a.stride(0),
        stride_ak=a.stride(1),
        stride_bk=b.stride(0),
        stride_bn=b.stride(1),
        stride_cm=c.stride(0),
        stride_cn=c.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )
    return c
# Quick correctness test
if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn(512, 256, device="cuda", dtype=torch.float16)
    b = torch.randn(256, 384, device="cuda", dtype=torch.float16)
    c_triton = matmul(a, b)
    c_torch = torch.matmul(a, b)
    print(c_triton, c_torch)
    print(f"Results match: {torch.allclose(c_triton, c_torch, atol=1e-2)}")


@@ -52,6 +52,16 @@ configs = [
        num_stages=3,  # number of pipeline stages (balances register pressure and latency hiding)
        # add more parameter combinations...
    ),
    triton.Config(
        {
            "BLOCK_SIZE_M": 64,  # tile size along the M (row) dimension
            "BLOCK_SIZE_N": 32,  # tile size along the N (column) dimension, exploiting column-direction parallelism
            "BLOCK_SIZE_K": 16,  # larger K-dimension tile to raise the compute-to-memory ratio
        },
        num_warps=8,  # each thread block contains 8 warps (256 threads)
        num_stages=3,  # number of pipeline stages (balances register pressure and latency hiding)
        # add more parameter combinations...
    ),
]
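# A hedged sketch (not part of this diff) of how a configs list like the one above is
# typically consumed: triton.autotune benchmarks each triton.Config the first time the
# kernel runs and caches the fastest one per distinct value of `key`. The kernel name,
# signature, and key below are illustrative assumptions, not taken from this repository;
# it also assumes `import triton` and `import triton.language as tl` at the top of the file.
@triton.autotune(configs=configs, key=["M", "N", "K"])
@triton.jit
def autotuned_matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    ...  # body omitted; the autotuner supplies BLOCK_SIZE_*, num_warps and num_stages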