diff --git a/test_triton.py b/test_triton.py
deleted file mode 100644
index bc3621a..0000000
--- a/test_triton.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# coding=utf-8
-import triton
-import triton.language as tl
-import torch
-
-configs = [
-    # Config for small matrices (e.g. M, N, K < 1024)
-    triton.Config(
-        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-        num_warps=4,  # warps per thread block (4 warps = 128 threads)
-        num_stages=3,  # pipeline stages (friendlier to compute-bound ops)
-    ),
-    # Config for medium matrices (e.g. 1024 ≤ M, N, K < 4096)
-    triton.Config(
-        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-        num_warps=8,
-        num_stages=4,
-    ),
-    # Config for large matrices (e.g. M, N, K ≥ 4096)
-    triton.Config(
-        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
-        num_warps=16,
-        num_stages=5,
-    ),
-    triton.Config(
-        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
-        num_warps=4,
-        num_stages=3,
-    ),
-    triton.Config(
-        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
-        num_warps=16,
-        num_stages=5,
-    ),
-    # Add more parameter combinations here...
-]
-
-
-@triton.autotune(configs=configs, key=["m", "n", "k"])
-@triton.jit
-def matmul_kernel(
-    a,
-    b,
-    c,
-    m,
-    n,
-    k,
-    stride_am,
-    stride_an,
-    stride_bm,
-    stride_bn,
-    stride_cm,
-    stride_cn,
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-):
-    mpid = tl.program_id(0)
-    npid = tl.program_id(1)
-    block_m = mpid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    block_n = npid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
-    for k_pad in range(0, k, BLOCK_SIZE_K):
-        k_range = tl.arange(0, BLOCK_SIZE_K) + k_pad
-        a_ptr = a + (block_m[:, None] * stride_am + k_range[None, :] * stride_an)
-        b_ptr = b + (k_range[:, None] * stride_bm + block_n[None, :] * stride_bn)
-        a_mask = (block_m[:, None] < m) & (k_range[None, :] < k)
-        b_mask = (k_range[:, None] < k) & (block_n[None, :] < n)
-
-        a_val = tl.load(a_ptr, mask=a_mask, other=0.0)
-        b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
-        acc += tl.dot(a_val, b_val)
-    acc.to(tl.bfloat16)
-    c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
-    c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
-    tl.store(c_ptr, acc, mask=c_mask)
-
-
-def triton_matmul(a, b):
-    # Allocate the output tensor
-    c = torch.empty(a.size(0), b.size(1), device="cuda", dtype=torch.bfloat16)
-
-    # Define the grid size
-    grid = lambda meta: (
-        triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
-        triton.cdiv(b.size(1), meta["BLOCK_SIZE_N"]),
-    )
-
-    # Launch the kernel
-    matmul_kernel[grid](
-        a,
-        b,
-        c,
-        m=a.size(0),
-        n=b.size(1),
-        k=a.size(1),
-        stride_am=a.stride(0),
-        stride_an=a.stride(1),
-        stride_bm=b.stride(0),
-        stride_bn=b.stride(1),
-        stride_cm=c.stride(0),
-        stride_cn=c.stride(1),
-    )
-    return c
-
-
-def benchmark(fn, *args, **kwargs):
-    # Warm up the GPU (so first-run initialization does not skew the timing)
-    for _ in range(10):
-        fn(*args, **kwargs)
-
-    # Time with CUDA events
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-
-    torch.cuda.synchronize()
-    start_event.record()
-    for _ in range(100):  # average over many runs
-        fn(*args, **kwargs)
-    end_event.record()
-    torch.cuda.synchronize()
-
-    return start_event.elapsed_time(end_event) / 100  # ms per call
-
-
-def benchmark_triton():
-    # Assumes the Triton kernel matmul_kernel, wrapped by triton_matmul
-    M, N, K = 4096 * 4, 4096 * 2, 4096 * 2  # matrix sizes
-    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
-    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
-
-    # Benchmark
-    time_ms = benchmark(triton_matmul, a, b)
-    print(f"Triton time: {time_ms:.3f} ms")
-
-
-def benchmark_torch():
-    # Uses torch.matmul as the PyTorch reference
-    M, N, K = 4096 * 4, 4096 * 2, 4096 * 2  # matrix sizes
-    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
-    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
-
-    # Benchmark
-    time_ms = benchmark(torch.matmul, a, b)
-    print(f"PyTorch time: {time_ms:.3f} ms")
-
-
-if __name__ == "__main__":
-    benchmark_triton()
-    benchmark_torch()
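Both the deleted script above and its replacement under tests/ only time the matmul; neither checks the result against torch.matmul. A minimal correctness check is sketched below. It assumes `triton_matmul` can be imported from the new tests/test_triton.py (the import path depends on the test layout), and the loose tolerances are an assumption driven by bfloat16 rounding plus the kernel's f32 accumulation, not a measured bound. Odd sizes such as M=500 would also exercise the boundary masks.

```python
# Sketch of a correctness check for the Triton matmul; the import path
# and the bf16-level tolerances are assumptions.
import torch

from tests.test_triton import triton_matmul


def check_matmul(M=512, N=512, K=512):
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    ref = torch.matmul(a, b)
    out = triton_matmul(a, b)
    # bf16 has ~8 bits of mantissa, so tolerances must be coarse
    assert torch.allclose(out, ref, rtol=1e-2, atol=1e-2), "matmul mismatch"


if __name__ == "__main__":
    check_matmul()
    print("OK")
```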
diff --git a/tests/test_triton.py b/tests/test_triton.py
index 7755d16..2ea0f34 100644
--- a/tests/test_triton.py
+++ b/tests/test_triton.py
@@ -1,84 +1,169 @@
-import torch
+# coding=utf-8
 import triton
 import triton.language as tl
+import torch
+from loguru import logger
+
+logger.add("triton.log", rotation="1 MB", retention="10 days", level="DEBUG")
+configs = [
+    # Config for small matrices (e.g. M, N, K < 1024)
+    triton.Config(
+        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
+        num_warps=4,  # warps per thread block (4 warps = 128 threads)
+        num_stages=3,  # pipeline stages (friendlier to compute-bound ops)
+    ),
+    # Config for medium matrices (e.g. 1024 ≤ M, N, K < 4096)
+    triton.Config(
+        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
+        num_warps=8,
+        num_stages=4,
+    ),
+    # Config for large matrices (e.g. M, N, K ≥ 4096)
+    triton.Config(
+        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
+        num_warps=16,
+        num_stages=5,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
+        num_warps=8,
+        num_stages=3,
+    ),
+    triton.Config(
+        {
+            "BLOCK_SIZE_M": 128,  # tall blocking along M
+            "BLOCK_SIZE_N": 256,  # wide blocking along N (exploits column-direction parallelism)
+            "BLOCK_SIZE_K": 64,  # larger K blocking to raise the compute/memory ratio
+        },
+        num_warps=8,  # 8 warps per thread block (256 threads)
+        num_stages=3,  # pipeline stages (balances register pressure and latency hiding)
+    ),
+    # Add more parameter combinations here...
+]
 
 
+@triton.autotune(configs=configs, key=["m", "n", "k"])
 @triton.jit
-def softmax_kernel(
-    output_ptr,
-    input_ptr,
-    input_row_stride,
-    output_row_stride,
-    n_cols,
-    BLOCK_SIZE: tl.constexpr,
+def matmul_kernel(
+    a,
+    b,
+    c,
+    m,
+    n,
+    k,
+    stride_am,
+    stride_an,
+    stride_bm,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
 ):
-    # Row index of the current program
-    row_idx = tl.program_id(0)
+    mpid = tl.program_id(0)
+    npid = tl.program_id(1)
+    block_m = mpid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    block_n = npid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
-    # Starting pointers of the input and output rows
-    row_start_ptr = input_ptr + row_idx * input_row_stride
-    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+    for k_pad in range(0, k, BLOCK_SIZE_K):
+        k_range = tl.arange(0, BLOCK_SIZE_K) + k_pad
+        a_ptr = a + (block_m[:, None] * stride_am + k_range[None, :] * stride_an)
+        b_ptr = b + (k_range[:, None] * stride_bm + block_n[None, :] * stride_bn)
+        a_mask = (block_m[:, None] < m) & (k_range[None, :] < k)
+        b_mask = (k_range[:, None] < k) & (block_n[None, :] < n)
 
-    # Load the input row into local memory
-    row_offsets = tl.arange(0, BLOCK_SIZE)
-    input_ptrs = row_start_ptr + row_offsets
-    row = tl.load(input_ptrs, mask=row_offsets < n_cols, other=-float("inf"))
-
-    # Compute the softmax
-    row_minus_max = row - tl.max(row, axis=0)  # numerical stability: subtract the max
-    numerator = tl.exp(row_minus_max)
-    denominator = tl.sum(numerator, axis=0)
-    softmax_output = numerator / denominator
-
-    # Write the result back to the output
-    output_ptrs = output_row_start_ptr + row_offsets
-    tl.store(output_ptrs, softmax_output, mask=row_offsets < n_cols)
+        a_val = tl.load(a_ptr, mask=a_mask, other=0.0)
+        b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
+        acc += tl.dot(a_val, b_val)
+    acc = acc.to(tl.bfloat16)  # assign the cast; a bare acc.to(...) is a no-op
+    c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
+    c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
+    tl.store(c_ptr, acc, mask=c_mask)
 
 
-def softmax(x):
-    n_rows, n_cols = x.shape
+def triton_matmul(a, b):
+    # Allocate the output tensor
+    c = torch.empty(a.size(0), b.size(1), device="cuda", dtype=torch.bfloat16)
 
-    # Allocate the output tensor
-    output = torch.empty_like(x)
-
-    # Grid and block sizes for the GPU kernel
-    BLOCK_SIZE = triton.next_power_of_2(n_cols)
-    num_warps = 4
-    if BLOCK_SIZE >= 2048:
-        num_warps = 8
-    if BLOCK_SIZE >= 4096:
-        num_warps = 16
-
-    # Launch the Triton kernel
-    softmax_kernel[(n_rows,)](
-        output,
-        x,
-        x.stride(0),
-        output.stride(0),
-        n_cols,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
+    # Define the grid size
+    grid = lambda meta: (
+        triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
+        triton.cdiv(b.size(1), meta["BLOCK_SIZE_N"]),
     )
-    return output
 
+    # Launch the kernel
+    matmul_kernel[grid](
+        a,
+        b,
+        c,
+        m=a.size(0),
+        n=b.size(1),
+        k=a.size(1),
+        stride_am=a.stride(0),
+        stride_an=a.stride(1),
+        stride_bm=b.stride(0),
+        stride_bn=b.stride(1),
+        stride_cm=c.stride(0),
+        stride_cn=c.stride(1),
+    )
+    return c
+
+
+def benchmark(fn, *args, **kwargs):
+    # Warm up the GPU (so first-run initialization does not skew the timing)
+    for _ in range(10):
+        fn(*args, **kwargs)
+
+    # Time with CUDA events
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
+    torch.cuda.synchronize()
+    start_event.record()
+    for _ in range(100):  # average over many runs
+        fn(*args, **kwargs)
+    end_event.record()
+    torch.cuda.synchronize()
+
+    return start_event.elapsed_time(end_event) / 100  # ms per call
+
+
+def benchmark_triton(M, N, K):
+    # Assumes the Triton kernel matmul_kernel, wrapped by triton_matmul
+    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
+
+    # Benchmark
+    time_ms = benchmark(triton_matmul, a, b)
+    print(f"Triton time: {time_ms:.3f} ms")
+    return time_ms
+
+
+def benchmark_torch(M, N, K):
+    # Uses torch.matmul as the PyTorch reference
+    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
+
+    # Benchmark
+    time_ms = benchmark(torch.matmul, a, b)
+    print(f"PyTorch time: {time_ms:.3f} ms")
+    return time_ms
 
-# Test the softmax
 if __name__ == "__main__":
-    # Create a random matrix
-    x = torch.randn(4, 16, device="cuda")
-    # Softmax with Triton
-    output_triton = softmax(x)
-
-    # Softmax with PyTorch as reference
-    output_torch = torch.softmax(x, dim=1)
-
-    # Check that the results agree
-    print("Input:")
-    print(x)
-    print("Triton Softmax:")
-    print(output_triton)
-    print("PyTorch Softmax:")
-    print(output_torch)
-    print(f"Are close: {torch.allclose(output_triton, output_torch, atol=1e-5)}")
+    for M in range(1024, 4096 * 4, 1024):
+        for N in range(1024, 4096 * 2, 1024):
+            for K in range(1024, 4096 * 2, 1024):
+                triton_cost_time = benchmark_triton(M, N, K)
+                torch_cost_time = benchmark_torch(M, N, K)
+                logger.info(
+                    f"{M}, {N}, {K} triton/torch rate: {triton_cost_time / torch_cost_time}",
+                )
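The hand-rolled CUDA-event loop above works, but Triton also ships a timing helper, `triton.testing.do_bench`, which handles warmup and repetition itself. A sketch of the same comparison with it follows; the warmup/rep values are illustrative, the import path for `triton_matmul` is an assumption, and which statistic is returned (mean vs. median milliseconds) varies across Triton versions.

```python
# Alternative timing via Triton's built-in helper; assumes the same
# `triton_matmul` wrapper as in tests/test_triton.py above.
import torch
from triton.testing import do_bench

from tests.test_triton import triton_matmul


def bench_pair(M=4096, N=4096, K=4096):
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    # do_bench runs the closure repeatedly and returns a time in ms
    ms_triton = do_bench(lambda: triton_matmul(a, b), warmup=25, rep=100)
    ms_torch = do_bench(lambda: torch.matmul(a, b), warmup=25, rep=100)
    print(f"triton: {ms_triton:.3f} ms, torch: {ms_torch:.3f} ms")
```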
diff --git a/tests/test_triton_softmax.py b/tests/test_triton_softmax.py
new file mode 100644
index 0000000..d2fbce3
--- /dev/null
+++ b/tests/test_triton_softmax.py
@@ -0,0 +1,57 @@
+# coding=utf-8
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def softmax_kernel(a_ptr, c_ptr, stride_m, stride_n, m, n, BLOCK_SIZE: tl.constexpr):
+    # Starting offsets of the current thread block
+    midx = tl.program_id(0)
+    nidx = tl.program_id(1)
+    m_offset = midx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    n_offset = nidx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+
+    # Mask to guard against out-of-bounds accesses
+    mask = (m_offset[:, None] < m) & (n_offset[None, :] < n)
+
+    # Compute the element indices
+    index = m_offset[:, None] * stride_m + n_offset[None, :] * stride_n
+
+    # Load the input tile; masked-off lanes get -inf so they cannot win
+    # the max and contribute exp(-inf) = 0 to the sum
+    a = tl.load(a_ptr + index, mask=mask, other=-float("inf"))
+
+    # Compute the softmax along the tile's second axis
+    max_a = tl.max(a, axis=1, keep_dims=True)
+    exp_a = tl.exp(a - max_a)
+    sum_exp_a = tl.sum(exp_a, axis=1, keep_dims=True)
+    softmax_a = exp_a / sum_exp_a
+
+    # Store the result
+    tl.store(c_ptr + index, softmax_a, mask=mask)
+
+
+def softmax(a, dim=0):
+    """
+    Softmax of the 2-D input tensor `a` along dimension `dim`,
+    computed with Triton.
+    """
+    m, n = a.shape
+    stride_m = a.stride(0)
+    stride_n = a.stride(1)
+
+    # Allocate output tensor
+    c = torch.empty_like(a)
+
+    # Collapse every dimension except `dim` into the launch grid
+    all_other_dim = 1
+    for i in range(len(a.size())):
+        if i != dim:
+            all_other_dim *= a.size(i)
+
+    BLOCK_SIZE = a.size(dim)
+    grid = (all_other_dim,)
+    # FIXME: The grid size should be adjusted based on the number of
+    # dimensions; the kernel also reads tl.program_id(1), which is always 0
+    # on this 1-D grid, and tl.arange requires BLOCK_SIZE to be a power of 2
+    softmax_kernel[grid](a, c, stride_m, stride_n, m, n, BLOCK_SIZE=BLOCK_SIZE)
+
+    return c
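Given that FIXME, the new softmax file would benefit from keeping the allclose smoke test that this change removed from tests/test_triton.py. A minimal sketch follows; the import path and the shape (square, power-of-two, so a single tile covers the whole matrix) are assumptions. Note the kernel reduces along its second tile axis, so `torch.softmax(x, dim=1)` is the natural reference for this shape; whether the wrapper's `dim` argument lines up with that is exactly what the FIXME is about.

```python
# Smoke test mirroring the check removed from tests/test_triton.py;
# import path and the 16x16 shape are assumptions.
import torch

from tests.test_triton_softmax import softmax


def check_softmax():
    x = torch.randn(16, 16, device="cuda")
    out_triton = softmax(x, dim=0)  # the wrapper's dim handling is the FIXME
    out_torch = torch.softmax(x, dim=1)
    print(f"Are close: {torch.allclose(out_triton, out_torch, atol=1e-5)}")


if __name__ == "__main__":
    check_softmax()
```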