# coding=utf-8
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(a_ptr, c_ptr, stride_m, stride_n, n, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one row; grid axis 0 indexes the rows.
    midx = tl.program_id(0)
    n_offset = tl.arange(0, BLOCK_SIZE)

    # Mask off columns past the end of the row to avoid out-of-bounds access.
    mask = n_offset < n

    # Memory offsets of this row's elements.
    index = midx * stride_m + n_offset * stride_n

    # Load the row; masked positions read as -inf so they contribute nothing
    # to the max or the sum below.
    a = tl.load(a_ptr + index, mask=mask, other=-float("inf"))

    # Numerically stable softmax: subtract the row max before exponentiating.
    max_a = tl.max(a, axis=0)
    exp_a = tl.exp(a - max_a)
    sum_exp_a = tl.sum(exp_a, axis=0)
    softmax_a = exp_a / sum_exp_a

    # Store the result.
    tl.store(c_ptr + index, softmax_a, mask=mask)


def softmax(a, dim=0):
    """
    Perform the softmax operation on the 2-D input tensor `a` along the
    specified dimension `dim`.

    This function uses Triton to accelerate the computation.
    """
    assert a.ndim == 2, "only 2-D tensors are supported"
    dim = dim % 2

    # Treat the reduced dimension as the inner axis by choosing sizes and
    # strides accordingly; no data movement is needed.
    m = a.size(1 - dim)
    n = a.size(dim)
    stride_m = a.stride(1 - dim)
    stride_n = a.stride(dim)

    # Allocate output tensor
    c = torch.empty_like(a)

    # tl.arange needs a power-of-two range, so round the reduced length up to
    # the next power of two; the kernel masks off the excess columns.
    BLOCK_SIZE = triton.next_power_of_2(n)

    # Launch one program instance per reduced row.
    grid = (m,)
    softmax_kernel[grid](a, c, stride_m, stride_n, n, BLOCK_SIZE=BLOCK_SIZE)

    return c
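

# A minimal sanity check, not part of the original file. It assumes a
# CUDA-capable GPU is available (Triton kernels need one to run) and
# compares the result against torch.softmax along both dimensions.
if __name__ == "__main__":
    x = torch.randn(128, 1000, device="cuda")
    torch.testing.assert_close(softmax(x, dim=1), torch.softmax(x, dim=1))
    torch.testing.assert_close(softmax(x, dim=0), torch.softmax(x, dim=0))
    print("Triton softmax matches torch.softmax")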