torch_ext/tests/test_triton_softmax.py

# coding=utf-8
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(a_ptr, c_ptr, stride_m, stride_n, m, n, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one slice of the softmax axis ("one row").
    # `m` (the number of slices) is implied by the launch grid; `n` is the length
    # of the softmax axis, and BLOCK_SIZE is `n` padded up to a power of two.
    midx = tl.program_id(0)
    n_offset = tl.arange(0, BLOCK_SIZE)
    # Mask off the padding lanes beyond the end of the row.
    mask = n_offset < n
    # Flat offsets of this row's elements.
    index = midx * stride_m + n_offset * stride_n
    # Load the row; padding lanes read -inf so they do not affect the max or the sum.
    a = tl.load(a_ptr + index, mask=mask, other=-float('inf'))
    # Numerically stable softmax: subtract the row maximum before exponentiating.
    max_a = tl.max(a, axis=0)
    exp_a = tl.exp(a - max_a)
    sum_exp_a = tl.sum(exp_a, axis=0)
    softmax_a = exp_a / sum_exp_a
    # Store the result.
    tl.store(c_ptr + index, softmax_a, mask=mask)


def softmax(a, dim=0):
    """
    Perform a softmax over the input tensor `a` along the specified dimension `dim`.
    This function uses Triton to accelerate the computation.
    Only 2-D tensors are supported for now.
    """
    # TODO: generalize the launch configuration to tensors with more than two dimensions.
    assert a.dim() == 2, "only 2-D inputs are supported"
    if dim < 0:
        dim += a.dim()
    # Number of independent slices, i.e. the product of all dimensions other than `dim`.
    all_other_dim = 1
    for i in range(a.dim()):
        if i != dim:
            all_other_dim *= a.size(i)
    # Stride between consecutive slices and stride along the softmax axis.
    stride_other = a.stride(1 - dim)
    stride_dim = a.stride(dim)
    # Allocate output tensor
    c = torch.empty_like(a)
    # Launch Triton kernel: one program per slice. tl.arange requires a power-of-two
    # block size, so the softmax axis is padded here and masked inside the kernel.
    BLOCK_SIZE = triton.next_power_of_2(a.size(dim))
    grid = (all_other_dim,)
    softmax_kernel[grid](a, c, stride_other, stride_dim,
                         all_other_dim, a.size(dim), BLOCK_SIZE=BLOCK_SIZE)
    return c
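

# A minimal sanity check (a sketch; the original test body is not shown above):
# compare the Triton result against torch.softmax on a small random tensor.
# Assumes a CUDA device is available.
if __name__ == "__main__":
    x = torch.randn(128, 1000, device="cuda")
    for d in (0, 1):
        y_triton = softmax(x, dim=d)
        y_torch = torch.softmax(x, dim=d)
        torch.testing.assert_close(y_triton, y_torch, rtol=1e-4, atol=1e-4)
    print("Triton softmax matches torch.softmax")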