From baaa5dbc1c2a907ac41cee613a8f8943b4aeaae3 Mon Sep 17 00:00:00 2001 From: longfei li Date: Sat, 12 Apr 2025 13:25:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B9=9F=E6=98=AF=E5=8F=AF=E4=BB=A5=E7=9A=84?= =?UTF-8?q?=EF=BC=8C=E6=80=BB=E4=BD=93=E7=94=A8=E8=B5=B7=E6=9D=A5=E8=BF=98?= =?UTF-8?q?=E7=9C=9F=E6=98=AF=E6=96=B9=E4=BE=BF=EF=BC=8C=E8=BF=98=E6=9C=89?= =?UTF-8?q?=E4=BB=80=E4=B9=88=E7=A8=80=E5=A5=87=E7=9A=84=E7=94=A8=E6=B3=95?= =?UTF-8?q?=E5=91=A2=EF=BC=8C=E5=8F=AF=E4=BB=A5=E7=BB=A7=E7=BB=AD=E7=A0=94?= =?UTF-8?q?=E7=A9=B6=E4=B8=80=E4=B8=8B=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_block_md.py | 132 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 tests/test_block_md.py diff --git a/tests/test_block_md.py b/tests/test_block_md.py new file mode 100644 index 0000000..fda816d --- /dev/null +++ b/tests/test_block_md.py @@ -0,0 +1,132 @@ +import torch +import triton +import triton.language as tl + + +# so, the block_ptr is a pointer to a block of memory, which can be used to load and store data in a block-wise manner. +# It is useful for implementing block-wise algorithms, such as matrix multiplication, where data is processed in blocks to improve memory access patterns and performance. 
# NOTE: A and B must share one element type supported by tl.dot (the demo at
# the bottom uses float16). Products are accumulated in float32 and the result
# is cast to the element type of the output tensor right before the store.


@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    BS,
    M,
    N,
    K,
    stride_ab,  # batch stride of A
    stride_am,
    stride_ak,  # row / column strides of A
    stride_bb,  # batch stride of B
    stride_bk,
    stride_bn,  # row / column strides of B
    stride_cb,  # batch stride of C
    stride_cm,
    stride_cn,  # row / column strides of C
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute one (BLOCK_M, BLOCK_N) tile of C[b] = A[b] @ B[b] using block pointers.

    The launch grid is (batch, tiles along M, tiles along N): program axis 0
    selects the batch entry, axes 1 and 2 select the output tile.

    a_ptr, b_ptr, c_ptr: base pointers of A (BS, M, K), B (BS, K, N), C (BS, M, N).
    BS: batch size; not read by the kernel, kept so existing launches stay valid.
    M, N, K: matrix dimensions.
    stride_*: element strides of the three tensors.
    BLOCK_M, BLOCK_N, BLOCK_K: compile-time tile sizes.
    """
    # Widen the batch index so that `pid_b * stride_*` cannot overflow 32-bit
    # arithmetic on large tensors.
    pid_b = tl.program_id(0).to(tl.int64)
    pid_m = tl.program_id(1)
    pid_n = tl.program_id(2)

    # Block pointers for the first K-tile of the A row panel and B column panel.
    # order=(1, 0) declares the last dimension as the fastest-varying one.
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr + pid_b * stride_ab,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr + pid_b * stride_bb,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(1, 0),
    )

    # Accumulate in float32 regardless of the input element type.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Walk along K one tile at a time (the loop variable itself is not needed;
    # the block pointers carry the position).
    for k in range(0, K, BLOCK_K):
        # Out-of-range elements must read as zero: without a padding option
        # their value is undefined and would corrupt the sum whenever M, N or K
        # is not a multiple of the tile size.
        a = tl.load(a_block_ptr, boundary_check=(0, 1), padding_option="zero")
        b = tl.load(b_block_ptr, boundary_check=(0, 1), padding_option="zero")

        accumulator += tl.dot(a, b)

        # Step both block pointers to the next K-tile.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))

    c_block_ptr = tl.make_block_ptr(
        base=c_ptr + pid_b * stride_cb,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    # A block-pointer store needs the value's element type to match the
    # pointer's, so cast to whatever C actually holds instead of a fixed type.
    result = accumulator.to(c_ptr.dtype.element_ty)
    # Guard the edges so partial tiles never write outside C.
    tl.store(c_block_ptr, result, boundary_check=(0, 1))


def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Batched matrix multiply ``a @ b`` computed by ``matmul_kernel``.

    Args:
        a: tensor of shape (BS, M, K).
        b: tensor of shape (BS, K, N), same dtype and device as ``a``.

    Returns:
        A new tensor of shape (BS, M, N) with the dtype and device of ``a``.

    Raises:
        ValueError: if the inputs are not 3-D, their batch or inner
            dimensions disagree, or their dtype/device differ.
    """
    # Explicit checks instead of `assert`: assertions vanish under `python -O`,
    # and a shape mismatch here would make the kernel read the wrong memory.
    if a.dim() != 3 or b.dim() != 3:
        raise ValueError("a and b must be 3-D tensors: (BS, M, K) and (BS, K, N)")
    if a.size(2) != b.size(1):
        raise ValueError("维度不匹配")
    if a.size(0) != b.size(0):
        raise ValueError("bs must be equal")
    if a.dtype != b.dtype or a.device != b.device:
        raise ValueError("a and b must have the same dtype and device")

    batch, M, K = a.shape
    N = b.size(2)

    # Every element of the output is written by exactly one program, so the
    # buffer does not need to be zero-initialised.
    c = torch.empty((batch, M, N), device=a.device, dtype=a.dtype)
    if c.numel() == 0:
        # Nothing to compute; avoid launching an empty grid.
        return c

    # Tile sizes (tunable).
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32

    grid = lambda meta: (
        batch,
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    matmul_kernel[grid](
        a_ptr=a,
        b_ptr=b,
        c_ptr=c,
        BS=batch,
        M=M,
        N=N,
        K=K,
        stride_ab=a.stride(0),
        stride_am=a.stride(1),
        stride_ak=a.stride(2),
        stride_bb=b.stride(0),
        stride_bk=b.stride(1),
        stride_bn=b.stride(2),
        stride_cb=c.stride(0),
        stride_cm=c.stride(1),
        stride_cn=c.stride(2),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )

    return c


# Smoke test: compare against torch.matmul on random float16 inputs.
if __name__ == "__main__":
    for _ in range(100):
        a = torch.randn(size=(128, 512, 256), device="cuda", dtype=torch.float16)
        b = torch.randn(size=(128, 256, 384), device="cuda", dtype=torch.float16)
        c_triton = matmul(a, b)
        c_torch = torch.matmul(a, b)
        # float16 outputs of this size have a spacing larger than 1e-2, so an
        # absolute tolerance alone is too strict; add a relative one.
        print(f"结果是否一致: {torch.allclose(c_triton, c_torch, atol=1e-2, rtol=1e-2)}")