Don't need to reduce row_sum during online softmax

Tri Dao 2024-02-20 13:33:03 -08:00
parent f45bbb4c94
commit b32efb1a4d

@@ -55,10 +55,10 @@ __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
     reduce_<zero_init>(tensor, max, max_op);
 }

-template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 __device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
     SumOp<float> sum_op;
-    reduce_(tensor, sum, sum_op);
+    thread_reduce_<zero_init>(tensor, sum, sum_op);
 }

 // Apply the exp to all the elements.
@@ -133,7 +133,7 @@ struct Softmax {
         if (Is_first) {
             flash::template reduce_max</*zero_init=*/true>(scores, row_max);
             flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
-            flash::reduce_sum(scores, row_sum);
+            flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
         } else {
             Tensor scores_max_prev = make_fragment_like(row_max);
             cute::copy(row_max, scores_max_prev);
@@ -152,15 +152,16 @@ struct Softmax {
                 for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
             }
             flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
-            Tensor scores_sum_cur = make_fragment_like(row_sum);
-            flash::reduce_sum(scores, scores_sum_cur);
-            #pragma unroll
-            for (int mi = 0; mi < size(row_sum); ++mi) { row_sum(mi) += scores_sum_cur(mi); }
+            // We don't do the reduce across threads here since we don't need to use the row_sum.
+            // We do that reduce at the end when we need to normalize the softmax.
+            flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
         }
     };

     template<bool Is_dropout=false, bool Split=false, typename Tensor0>
     __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+        SumOp<float> sum_op;
+        quad_allreduce_(row_sum, row_sum, sum_op);
         TensorT lse = make_fragment_like(row_sum);
         Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
         static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
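Why deferring the reduce is safe: during the online-softmax loop only row_max has to be exact across key blocks, since it drives the exponentials and the rescaling of acc_o; row_sum is read exactly once, at the final normalization in normalize_softmax_lse, so each thread can keep a private partial sum over the columns it owns and the cross-thread combine (quad_allreduce_) can wait until then. Below is a minimal single-threaded C++ sketch of that idea, not the kernel code: kBlock, kThreads, partial_sum, and acc are made-up illustrative names, and the per-"thread" column ownership is simplified.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    // One attention row of scores, processed online in blocks of kBlock columns,
    // with each block's columns split across kThreads "threads".
    std::vector<float> scores = {0.3f, -1.2f, 2.0f, 0.5f, 1.7f, -0.4f, 0.1f, 2.3f};
    const int kBlock = 4, kThreads = 2;

    float row_max = -INFINITY;
    std::vector<float> partial_sum(kThreads, 0.0f);  // per-"thread" running sums, never combined in the loop
    std::vector<float> acc(scores.size(), 0.0f);     // stand-in for the acc_o numerator

    for (std::size_t start = 0; start < scores.size(); start += kBlock) {
        // 1. Update the running max -- the only statistic the loop truly needs.
        float prev_max = row_max;
        for (int i = 0; i < kBlock; ++i) row_max = std::fmax(row_max, scores[start + i]);
        // 2. Rescale the old accumulator and the old partial sums to the new max.
        float scale = (prev_max == -INFINITY) ? 1.0f : std::exp(prev_max - row_max);
        for (float &a : acc) a *= scale;
        for (float &s : partial_sum) s *= scale;
        // 3. Accumulate exp(score - row_max); each "thread" adds only its own columns,
        //    so no cross-thread reduce happens here (mirrors reduce_sum</*zero_init=*/false>).
        for (int i = 0; i < kBlock; ++i) {
            float e = std::exp(scores[start + i] - row_max);
            acc[start + i] = e;
            partial_sum[i % kThreads] += e;
        }
    }

    // Only now combine the partials (the kernel's quad_allreduce_) and normalize,
    // as normalize_softmax_lse does once at the end.
    float row_sum = 0.0f;
    for (float s : partial_sum) row_sum += s;
    for (float &a : acc) a /= row_sum;

    float check = 0.0f;
    for (float a : acc) check += a;
    std::printf("softmax sums to %.6f\n", check);  // ~1.0
    return 0;
}

The same observation is what lets the second hunk fold the per-block accumulation into reduce_sum</*zero_init=*/false>: accumulating directly into row_sum removes the scores_sum_cur fragment and the extra add loop.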