# torch_ext/tests/test_triton.py
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # Row index handled by this program instance (one program per row)
    row_idx = tl.program_id(0)
    # Starting pointers of the input and output rows
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    # Load the row into on-chip memory; out-of-bounds lanes read -inf
    # so they contribute nothing to the max or the sum
    row_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + row_offsets
    row = tl.load(input_ptrs, mask=row_offsets < n_cols, other=-float("inf"))
    # Compute softmax: subtract the row max for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write the result back to the output row
    output_ptrs = output_row_start_ptr + row_offsets
    tl.store(output_ptrs, softmax_output, mask=row_offsets < n_cols)
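

# For comparison, the same numerically stable softmax written in plain PyTorch.
# This is an illustrative sketch, not part of the original file; `naive_softmax`
# is a hypothetical helper that mirrors the kernel's math step by step.
def naive_softmax(x):
    # Subtract the per-row max before exponentiating, exactly as the kernel does
    row_max = x.max(dim=1, keepdim=True).values
    numerator = torch.exp(x - row_max)
    denominator = numerator.sum(dim=1, keepdim=True)
    return numerator / denominator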


def softmax(x):
    n_rows, n_cols = x.shape
    # Allocate the output tensor
    output = torch.empty_like(x)
    # Each program instance handles one full row, so BLOCK_SIZE must cover n_cols
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # Use more warps for wider rows
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    # Launch the Triton kernel with a 1D grid of n_rows program instances
    softmax_kernel[(n_rows,)](
        output,
        x,
        x.stride(0),
        output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return output


# Test the softmax kernel against PyTorch
if __name__ == "__main__":
    # Create a random matrix
    x = torch.randn(4, 16, device="cuda")
    # Compute softmax with Triton
    output_triton = softmax(x)
    # Compute softmax with PyTorch as a reference
    output_torch = torch.softmax(x, dim=1)
    # Check that the two results agree
    print("Input:")
    print(x)
    print("Triton Softmax:")
    print(output_triton)
    print("PyTorch Softmax:")
    print(output_torch)
    print(f"Are close: {torch.allclose(output_triton, output_torch, atol=1e-5)}")