From b32efb1a4d36a890c689a15eb0848922e11690b0 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 20 Feb 2024 13:33:03 -0800 Subject: [PATCH] Don't need to reduce row_sum during online softmax --- csrc/flash_attn/src/softmax.h | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index 5bfa771..189f2e2 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -55,10 +55,10 @@ __device__ __forceinline__ void reduce_max(Tensor const& tenso reduce_(tensor, max, max_op); } -template +template __device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ SumOp sum_op; - reduce_(tensor, sum, sum_op); + thread_reduce_(tensor, sum, sum_op); } // Apply the exp to all the elements. @@ -133,7 +133,7 @@ struct Softmax { if (Is_first) { flash::template reduce_max(scores, row_max); flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - flash::reduce_sum(scores, row_sum); + flash::reduce_sum(scores, row_sum); } else { Tensor scores_max_prev = make_fragment_like(row_max); cute::copy(row_max, scores_max_prev); @@ -152,15 +152,16 @@ struct Softmax { for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } } flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - Tensor scores_sum_cur = make_fragment_like(row_sum); - flash::reduce_sum(scores, scores_sum_cur); - #pragma unroll - for (int mi = 0; mi < size(row_sum); ++mi) { row_sum(mi) += scores_sum_cur(mi); } + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); } }; template __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); TensorT lse = make_fragment_like(row_sum); Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);