# coding=utf-8
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(a_ptr, c_ptr, stride_m, stride_n, m, n, BLOCK_SIZE: tl.constexpr):
    # Index of the row handled by this program instance.
    midx = tl.program_id(0)
    n_offset = tl.arange(0, BLOCK_SIZE)
    # Mask so that elements beyond the row length n are never read or written.
    mask = n_offset < n
    # Memory offsets of this row's elements.
    index = midx * stride_m + n_offset * stride_n
    # Load the row; masked-off elements become -inf so they do not affect
    # the row max or the sum of exponentials.
    a = tl.load(a_ptr + index, mask=mask, other=-float("inf"))
    # Numerically stable softmax: subtract the row max before exponentiating.
    max_a = tl.max(a, axis=0)
    exp_a = tl.exp(a - max_a)
    sum_exp_a = tl.sum(exp_a, axis=0)
    softmax_a = exp_a / sum_exp_a
    # Store the result.
    tl.store(c_ptr + index, softmax_a, mask=mask)


def softmax(a, dim=0):
    """
    Perform softmax operation on the 2D input tensor `a` along the specified
    dimension `dim`. This function uses Triton to accelerate the computation.
    """
    assert a.ndim == 2, "only 2D tensors are supported"
    dim = dim % a.ndim
    if dim == 1:
        m, n = a.shape
        stride_m, stride_n = a.stride(0), a.stride(1)
    else:
        # The kernel always reduces along its second ("n") axis, so a softmax
        # along dim 0 is expressed by swapping the roles of rows and columns.
        n, m = a.shape
        stride_n, stride_m = a.stride(0), a.stride(1)
    # Allocate output tensor
    c = torch.empty_like(a)
    # tl.arange requires a power-of-two size, and one block must cover an
    # entire reduction row so the max and sum see every element of that row.
    BLOCK_SIZE = triton.next_power_of_2(n)
    # Launch Triton kernel: one program per reduction row.
    grid = (m,)
    softmax_kernel[grid](a, c, stride_m, stride_n, m, n, BLOCK_SIZE=BLOCK_SIZE)
    return c
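

# --- Usage sketch (not from the original source) ---
# A minimal sanity check against torch.softmax, assuming a CUDA-capable GPU is
# available; the tensor shape and the non-power-of-two row length below are
# illustrative choices, not values taken from the original code.
if __name__ == "__main__":
    x = torch.randn(64, 130, device="cuda")
    for d in (0, 1):
        out_triton = softmax(x, dim=d)
        out_torch = torch.softmax(x, dim=d)
        print(f"dim={d}, max abs diff:", (out_triton - out_torch).abs().max().item())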