import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # Get the row index handled by this program instance
    row_idx = tl.program_id(0)

    # Compute the start pointers of the input and output rows
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Load the row into SRAM; BLOCK_SIZE may exceed n_cols, so mask
    # out-of-bounds lanes and fill them with -inf (the identity for max)
    row_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + row_offsets
    row = tl.load(input_ptrs, mask=row_offsets < n_cols, other=-float("inf"))

    # Compute the softmax
    row_minus_max = row - tl.max(row, axis=0)  # subtract the row max for numerical stability
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # Write the result back to the output row
    output_ptrs = output_row_start_ptr + row_offsets
    tl.store(output_ptrs, softmax_output, mask=row_offsets < n_cols)


def softmax(x):
    n_rows, n_cols = x.shape

    # Allocate the output tensor
    output = torch.empty_like(x)

    # Each program instance processes one full row in a single pass, so
    # BLOCK_SIZE must be a power of two that covers the whole row
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16

    # Launch the Triton kernel with a 1D grid of one program per row
    softmax_kernel[(n_rows,)](
        output,
        x,
        x.stride(0),
        output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return output


# Test the softmax
if __name__ == "__main__":
    # Create a random matrix
    x = torch.randn(4, 16, device="cuda")

    # Compute softmax with Triton
    output_triton = softmax(x)

    # Compute softmax with PyTorch as a reference
    output_torch = torch.softmax(x, dim=1)

    # Check that the results agree
    print("Input:")
    print(x)
    print("Triton Softmax:")
    print(output_triton)
    print("PyTorch Softmax:")
    print(output_torch)
    print(f"Are close: {torch.allclose(output_triton, output_torch, atol=1e-5)}")
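

# --- Optional benchmark sketch (not part of the original test above) ---
# A minimal comparison of the Triton kernel against torch.softmax, assuming a
# CUDA device and that triton.testing.do_bench is available (its exact
# signature and defaults vary across Triton versions; it reports runtime in
# milliseconds). The matrix shape below is an arbitrary illustrative choice.
def benchmark(n_rows=4096, n_cols=1024):
    x = torch.randn(n_rows, n_cols, device="cuda")
    ms_triton = triton.testing.do_bench(lambda: softmax(x))
    ms_torch = triton.testing.do_bench(lambda: torch.softmax(x, dim=1))
    print(f"Triton softmax:  {ms_triton:.4f} ms")
    print(f"PyTorch softmax: {ms_torch:.4f} ms")


if __name__ == "__main__":
    benchmark()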