# Fused softmax implemented with Triton.
import torch

import triton
import triton.language as tl
@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Numerically stable row-wise softmax; one program instance per row.

    Args:
        output_ptr: pointer to the output matrix.
        input_ptr: pointer to the input matrix.
        input_row_stride: stride (in elements) between input rows.
        output_row_stride: stride (in elements) between output rows.
        n_cols: number of valid columns in each row.
        BLOCK_SIZE: compile-time block width; must be >= n_cols
            (the caller pads it to the next power of two).
    """
    # Each program along axis 0 owns exactly one row.
    pid = tl.program_id(0)

    # Base pointers of the row this program reads and writes.
    in_row_ptr = input_ptr + pid * input_row_stride
    out_row_ptr = output_ptr + pid * output_row_stride

    # Column offsets for the whole (padded) row.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    valid = col_offsets < n_cols

    # Masked lanes load -inf so they contribute exp(-inf) = 0 to the sum.
    vals = tl.load(in_row_ptr + col_offsets, mask=valid, other=-float("inf"))

    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = vals - tl.max(vals, axis=0)
    exp_vals = tl.exp(shifted)
    result = exp_vals / tl.sum(exp_vals, axis=0)

    # Write back only the valid columns.
    tl.store(out_row_ptr + col_offsets, result, mask=valid)
def softmax(x):
    """Row-wise softmax of a 2-D CUDA tensor via the Triton kernel.

    Args:
        x: 2-D tensor of shape (n_rows, n_cols) on a CUDA device.

    Returns:
        A new tensor of the same shape with softmax applied along dim 1.
    """
    n_rows, n_cols = x.shape

    # Output buffer with the same shape/dtype/device as the input.
    output = torch.empty_like(x)

    # One block must cover an entire row; pad width to the next power of two.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Wider rows get more warps per program.
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    else:
        num_warps = 4

    # Launch one program instance per row.
    grid = (n_rows,)
    softmax_kernel[grid](
        output,
        x,
        x.stride(0),
        output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return output
# Smoke-test the Triton softmax against the PyTorch reference.
if __name__ == "__main__":
    # Random test matrix on the GPU.
    x = torch.randn(4, 16, device="cuda")

    # Triton result and the PyTorch reference (softmax along rows).
    output_triton = softmax(x)
    output_torch = torch.softmax(x, dim=1)

    # Show both results and whether they agree within tolerance.
    print("Input:")
    print(x)
    print("Triton Softmax:")
    print(output_triton)
    print("PyTorch Softmax:")
    print(output_torch)
    print(f"Are close: {torch.allclose(output_triton, output_torch, atol=1e-5)}")