Don't need to reduce row_sum during online softmax

This commit is contained in:
Tri Dao 2024-02-20 13:33:03 -08:00
parent f45bbb4c94
commit b32efb1a4d

View File

@ -55,10 +55,10 @@ __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tenso
reduce_<zero_init>(tensor, max, max_op);
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
reduce_(tensor, sum, sum_op);
thread_reduce_<zero_init>(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
@ -133,7 +133,7 @@ struct Softmax {
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum(scores, row_sum);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
@ -152,15 +152,16 @@ struct Softmax {
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
Tensor scores_sum_cur = make_fragment_like(row_sum);
flash::reduce_sum(scores, scores_sum_cur);
#pragma unroll
for (int mi = 0; mi < size(row_sum); ++mi) { row_sum(mi) += scores_sum_cur(mi); }
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);