/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

// NOTE(review): in the reviewed copy of this header the angle-bracket include targets, the
// template parameter lists and the explicit template arguments were not legible. They are
// restored below from how the code uses them (INFINITY / M_LOG2E / exp2f, cute::Tensor,
// zero_init, Scale_max, ...). Please confirm them against the build, in particular the
// cutlass include, the Engine* parameter names and the <float> functor arguments.
#include <cmath>

#include <cute/tensor.hpp>

#include <cutlass/numeric_types.h>

#include "philox.cuh"
#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Per-thread row reduction over a 2D tensor.
// For every row mi, folds tensor(mi, 0..N-1) with `op` into summary(mi).
//   zero_init == true : summary(mi) is overwritten, the fold starts from tensor(mi, 0).
//   zero_init == false: the incoming summary(mi) takes part in the fold as well.
// No cross-thread communication happens here.
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); mi++) {
        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            summary(mi) = op(summary(mi), tensor(mi, ni));
        }
    }
}

// Element-wise dst(i) = Allreduce<4>::run(src(i), op). dst and src may be the same tensor.
// NOTE(review): Allreduce is not defined in this file (presumably utils.h); going by the name it
// combines the value across a group of 4 threads - confirm there.
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int i = 0; i < size(dst); i++) {
        dst(i) = Allreduce<4>::run(src(i), op);
    }
}

// Full row reduction: thread-local fold (thread_reduce_) followed by the in-place
// quad_allreduce_ on the per-row results.
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    thread_reduce_<zero_init>(tensor, summary, op);
    quad_allreduce_(summary, summary, op);
}

// Row-wise maximum of `tensor` into `max`. With zero_init == false the values already stored in
// `max` are kept as candidates, i.e. `max` acts as a running maximum.
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max) {
    MaxOp<float> max_op;
    reduce_<zero_init>(tensor, max, max_op);
}

// Row-wise sum of `tensor` into `sum`. Always overwrites `sum` (reduce_ is used with its default
// zero_init).
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum) {
    SumOp<float> sum_op;
    reduce_(tensor, sum, sum_op);
}

// Apply the exp to all the elements, in place:
//   tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled(mi))
// where max_scaled(mi) is max(mi) * scale when Scale_max is true and max(mi) * log2(e) otherwise.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't have float around M_LOG2E the multiplication is done in fp64.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)). This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
        }
    }
}

// Fused row max + exp2 + row sum. For every row mi:
//   1. max(mi) = maximum over the row (combined with the incoming max(mi) unless zero_init),
//      followed by Allreduce<4>;
//   2. tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max(mi) * scale), in place;
//   3. sum(mi) = sum of the exponentiated row, followed by Allreduce<4>. sum(mi) is always
//      overwritten, regardless of zero_init.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        MaxOp<float> max_op;
        max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            max(mi) = max_op(max(mi), tensor(mi, ni));
        }
        max(mi) = Allreduce<4>::run(max(mi), max_op);
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
        sum(mi) = 0;
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)). This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            sum(mi) += tensor(mi, ni);
        }
        SumOp<float> sum_op;
        sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// One step of the running softmax over a new block of scores.
//   Is_first == true : scores_max and scores_sum are computed from scratch from `scores`, and
//                      `scores` is exponentiated in place. acc_o is not touched.
//   Is_first == false: scores_max is updated as a running maximum; scores_sum and every row of
//                      acc_o are rescaled by exp2f((old_max - new_max) * softmax_scale_log2);
//                      `scores` is exponentiated in place against the new maximum and its row
//                      sums are added to scores_sum.
//   Check_inf        : when set, a new row maximum of -inf is treated as 0 in the rescale factor
//                      so that (-inf) - (-inf) cannot produce NaN.
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
                                         Tensor2 &acc_o, float softmax_scale_log2) {
    if (Is_first) {
        flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
        flash::reduce_sum(scores, scores_sum);
    } else {
        // Keep the previous maximum: it is needed for the rescale factor below.
        Tensor scores_max_prev = make_fragment_like(scores_max);
        cute::copy(scores_max, scores_max_prev);
        flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        #pragma unroll
        for (int mi = 0; mi < size(scores_max); ++mi) {
            float scores_max_cur = !Check_inf
                ? scores_max(mi)
                : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
            float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
            scores_sum(mi) *= scores_scale;
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
                acc_o_rowcol(mi, ni) *= scores_scale;
            }
        }
        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
        // Row sums of the freshly exponentiated block, accumulated into the running sum.
        Tensor scores_sum_cur = make_fragment_like(scores_sum);
        flash::reduce_sum(scores, scores_sum_cur);
        #pragma unroll
        for (int mi = 0; mi < size(scores_sum); ++mi) {
            scores_sum(mi) += scores_sum_cur(mi);
        }
    }
}

}  // namespace flash