From 6f706eff965e79f5bbeec2ac2ff8dba42370123e Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 15 Jan 2024 15:25:02 -0800
Subject: [PATCH] Make Softmax an object

---
 csrc/flash_attn/src/flash_fwd_kernel.h | 55 ++++------------
 csrc/flash_attn/src/softmax.h          | 81 +++++++++++++++++---------
 2 files changed, 67 insertions(+), 69 deletions(-)

diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index e91bbb6..ee576f6 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -180,10 +180,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
 
-    // TODO: this might need to change if we change the mma instruction in SM70
-    Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
-    Tensor scores_sum = make_fragment_like(scores_max);
-
     //
     // PREDICATES
     //
@@ -267,6 +263,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     clear(acc_o);
 
+    flash::Softmax<2 * size<1>(acc_o)> softmax;
+
     const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
     flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
 
@@ -357,8 +355,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         // TODO: when we have key_padding_mask we'll need to Check_inf
         masking_step == 0
-            ? flash::softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax.template softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
+            : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
 
         // Convert scores from fp32 to fp16/bf16
         Tensor rP = flash::convert_type<Element>(scores);
@@ -435,7 +433,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             );
         }
 
-        flash::softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        softmax.template softmax_rescale_o</*Is_first=*/false>(acc_s, acc_o, params.scale_softmax_log2);
 
         Tensor rP = flash::convert_type<Element>(scores);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
@@ -461,20 +459,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     // Epilogue
 
-    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
-    Tensor lse = make_fragment_like(scores_sum);
-    #pragma unroll
-    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
-        float sum = scores_sum(mi);
-        float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
-        lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
-        float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
-        #pragma unroll
-        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
-    }
-
-    // if (cute::thread0()) { print(acc_o_rowcol); }
+    Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
 
     // Convert acc_o from fp32 to fp16/bf16
     Tensor rO = flash::convert_type<Element>(acc_o);
@@ -685,11 +670,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
 
-    // TODO: this might need to change if we change the mma instruction in SM70
-    Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
-    Tensor scores_sum = make_fragment_like(scores_max);
-
-
     //
     // PREDICATES
     //
@@ -862,6 +842,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     clear(acc_o);
 
+    flash::Softmax<2 * size<1>(acc_o)> softmax;
+
     const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
     flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
 
@@ -939,8 +921,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         // We have key_padding_mask so we'll need to Check_inf
         masking_step == 0
-            ? flash::softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax.template softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2)
+            : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);
         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
 
         // Convert scores from fp32 to fp16/bf16
@@ -1002,7 +984,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 params.window_size_left, params.window_size_right
             );
         }
-        softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        softmax.template softmax_rescale_o</*Is_first=*/false>(acc_s, acc_o, params.scale_softmax_log2);
 
         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -1014,21 +996,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     // Epilogue
 
-    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
-    // if (cute::thread0()) { print(acc_o_rowcol); }
-    Tensor lse = make_fragment_like(scores_sum);
-    #pragma unroll
-    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
-        float sum = scores_sum(mi);
-        float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
-        lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum);
-        float scale = inv_sum;
-        #pragma unroll
-        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
-    }
+    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
     // if (cute::thread0()) { print(lse); }
-    // if (cute::thread0()) { print(acc_o_rowcol); }
 
     Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
     // Partition sO to match the accumulator partitioning
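
Note on the kernel-side diff above: the change is mechanical. The loose scores_max/scores_sum fragments become state inside a single flash::Softmax object, constructed once next to clear(acc_o) and used in exactly two places: softmax_rescale_o once per K/V block in the main loop, and normalize_softmax_lse once in the epilogue. The recurrence behind those two calls is online softmax. Below is a minimal host-side model of it for a single row of plain floats, not kernel code: the scale factors mirror scale_softmax/scale_softmax_log2, while the block data, the head_dim = 1 values, and the variable names are invented for illustration.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const float softmax_scale = 0.125f;                            // e.g. 1/sqrt(head_dim)
    const float scale_log2 = softmax_scale * 1.4426950408889634f;  // * log2(e); plays the role of scale_softmax_log2
    // Two K/V blocks' worth of scores S = Q K^T and values V (head_dim = 1 for brevity).
    std::vector<std::vector<float>> s_blocks = {{1.f, 2.f, 3.f}, {4.f, 0.5f, -1.f}};
    std::vector<std::vector<float>> v_blocks = {{10.f, 20.f, 30.f}, {40.f, 50.f, 60.f}};

    float row_max = -INFINITY, row_sum = 0.f;  // the Softmax object's per-row state
    float acc_o = 0.f;                         // the output accumulator
    for (size_t b = 0; b < s_blocks.size(); ++b) {
        // softmax_rescale_o: fold this block into the running max, then
        // rescale the previous sum and accumulator to the new max.
        float block_max = -INFINITY;
        for (float s : s_blocks[b]) block_max = std::fmax(block_max, s);
        float new_max = std::fmax(row_max, block_max);
        float rescale = std::exp2((row_max - new_max) * scale_log2);  // exp2(-inf) == 0 on the first block
        row_sum *= rescale;
        acc_o *= rescale;
        row_max = new_max;
        for (size_t i = 0; i < s_blocks[b].size(); ++i) {
            float p = std::exp2((s_blocks[b][i] - row_max) * scale_log2);  // un-normalized probability
            row_sum += p;
            acc_o += p * v_blocks[b][i];                                   // running P @ V
        }
    }
    // normalize_softmax_lse: a single division at the end, plus the log-sum-exp.
    float out = acc_o / row_sum;
    float lse = row_max * softmax_scale + std::log(row_sum);
    std::printf("out = %f  lse = %f\n", out, lse);
    return 0;
}

Keeping row_max/row_sum inside the object is what lets both epilogues drop their hand-written loops: by the time normalize_softmax_lse runs, all the state needed to renormalize acc_o and emit the LSE already lives in one place.
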
diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h
index 165acb8..749dffa 100644
--- a/csrc/flash_attn/src/softmax.h
+++ b/csrc/flash_attn/src/softmax.h
@@ -117,35 +117,64 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
-inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
-                                         Tensor2 &acc_o, float softmax_scale_log2) {
-    if (Is_first) {
-        flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
-        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
-        flash::reduce_sum(scores, scores_sum);
-    } else {
-        Tensor scores_max_prev = make_fragment_like(scores_max);
-        cute::copy(scores_max, scores_max_prev);
-        flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
-        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
-        #pragma unroll
-        for (int mi = 0; mi < size(scores_max); ++mi) {
-            float scores_max_cur = !Check_inf
-                ? scores_max(mi)
-                : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
-            float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
-            scores_sum(mi) *= scores_scale;
-            #pragma unroll
-            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
-        }
-        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
-        Tensor scores_sum_cur = make_fragment_like(scores_sum);
-        flash::reduce_sum(scores, scores_sum_cur);
-        #pragma unroll
-        for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
-    }
-}
+template <int kNRows>
+struct Softmax {
+
+    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
+    TensorT row_max, row_sum;
+
+    inline __device__ Softmax() {};
+
+    template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
+    inline __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
+        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+        static_assert(decltype(size<0>(scores))::value == kNRows);
+        if (Is_first) {
+            flash::template reduce_max</*zero_init=*/true>(scores, row_max);
+            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+            flash::reduce_sum(scores, row_sum);
+        } else {
+            Tensor scores_max_prev = make_fragment_like(row_max);
+            cute::copy(row_max, scores_max_prev);
+            flash::template reduce_max</*zero_init=*/false>(scores, row_max);
+            // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+            Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+            static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
+            #pragma unroll
+            for (int mi = 0; mi < size(row_max); ++mi) {
+                float scores_max_cur = !Check_inf
+                    ? row_max(mi)
+                    : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
+                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
+                row_sum(mi) *= scores_scale;
+                #pragma unroll
+                for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
+            }
+            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+            Tensor scores_sum_cur = make_fragment_like(row_sum);
+            flash::reduce_sum(scores, scores_sum_cur);
+            #pragma unroll
+            for (int mi = 0; mi < size(row_sum); ++mi) { row_sum(mi) += scores_sum_cur(mi); }
+        }
+    };
+
+    template<bool Is_dropout=false, bool Split=false, typename Tensor0>
+    inline __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+        TensorT lse = make_fragment_like(row_sum);
+        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
+        #pragma unroll
+        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
+            float sum = row_sum(mi);
+            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
+            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
+            #pragma unroll
+            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
+        }
+        return lse;
+    };
+};
 
 }  // namespace flash
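
One behavioral detail worth calling out in the new struct: both numerical guards survive the move. In softmax_rescale_o, Check_inf replaces a -INFINITY running max (a fully-masked row) with 0.0f before forming the rescale factor, avoiding the (-inf) - (-inf) = NaN subtraction inside exp2f. In normalize_softmax_lse, a zero or NaN row_sum yields inv_sum = 1 (so the all-zero accumulator passes through unchanged) and an infinite LSE: -INFINITY when Split is set, presumably so an empty split drops out when partial results are later combined. Below is a small host-side check of what the epilogue guard prevents, using the same scalar model as above, with split standing in for the Split template flag.

#include <cmath>
#include <cstdio>

int main() {
    const float softmax_scale = 0.125f;
    // A row that never saw an unmasked score: running max is -inf, running sum is 0.
    float row_max = -INFINITY, row_sum = 0.f, acc_o = 0.f;

    // Unguarded epilogue: 0/0 poisons the output with NaN.
    std::printf("unguarded: out = %f, lse = %f\n",
                acc_o / row_sum, row_max * softmax_scale + std::log(row_sum));

    // Guarded epilogue, as in normalize_softmax_lse: output stays 0, lse becomes +/-inf.
    const bool split = false;                              // stand-in for the Split template flag
    bool empty = (row_sum == 0.f || row_sum != row_sum);   // zero or NaN sum
    float inv_sum = empty ? 1.f : 1.f / row_sum;
    float lse = empty ? (split ? -INFINITY : INFINITY)
                      : row_max * softmax_scale + std::log(row_sum);
    std::printf("guarded:   out = %f, lse = %f\n", acc_o * inv_sum, lse);
    return 0;
}
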