From 5f1ae4a34bf6a2a638b3eda231776874e6804980 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 21 Jul 2024 23:25:46 -0700 Subject: [PATCH] backwards for softcapping (#1033) * check in the two ways of approaching backwards for softcapping, both functional * prepare the softcap switch for backwards * temporary * cleanup to the way Tri prefers * calculate dtanh when copying from scores -> dtanh Tensor * no ternary operators allowed for constexpr, so just use some hack found online * fix maybe_dtanh, restore some files * restore another file * move calculate_dtanh to utils and colocate with apply_softcap * cleanup * maybe last cleanup * save for another pr * remove a stray line * fix spacing * fix an issue, and make test_flash_attn.py ready to test softcapping backwards --- csrc/flash_attn/src/flash_bwd_kernel.h | 32 ++++++++++++++++--- .../src/flash_bwd_launch_template.h | 28 ++++++++-------- csrc/flash_attn/src/flash_fwd_kernel.h | 16 +++------- csrc/flash_attn/src/utils.h | 18 +++++++++++ tests/test_flash_attn.py | 8 ++--- 5 files changed, 69 insertions(+), 33 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 0adf0d5..00cbc08 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -76,7 +76,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -471,10 +471,27 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); + if constexpr (Is_softcap) { + flash::apply_softcap(acc_s, params.softcap); + } + // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); // if (cute::thread(32, 0)) { print(scores); } + // Softcapping - calculating dTanh and scaling dS later with it + auto dtanh = ([&]{ + if constexpr (Is_softcap) { + Tensor _dtanh = make_tensor_like(scores); + flash::calculate_dtanh(scores, _dtanh, params.softcap); + return _dtanh; + } + else { + return nullptr; + } + }()); + + // Alibi if (Has_alibi) { alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16); @@ -574,7 +591,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in for (int mi = 0; mi < size<0>(dS); ++mi) { #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { - dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + + float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + + if constexpr (Is_softcap) { + scaled_ds *= dtanh(mi, ni); + } + + dS(mi, ni) = scaled_ds; } } // if (cute::thread0()) { print(dS); } @@ -807,7 +831,7 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // The block index for the batch. 
@@ -817,7 +841,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } } diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index fd81c88..9168914 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -35,10 +35,10 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bo #endif } -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - flash::compute_dq_dk_dv_seqk_parallel(params); + flash::compute_dq_dk_dv_seqk_parallel(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -95,17 +95,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index edaf605..788f379 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -22,14 +22,6 @@ namespace flash { using namespace cute; -template -__forceinline__ __device__ void apply_softcap(Tensor &tensor, const float softcap){ - #pragma unroll - for (int i = 0; i < size(tensor); ++i) { - tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); - } -} - //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -327,7 +319,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi ); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } mask.template apply_mask( @@ -393,7 +385,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi smem_thr_copy_Q, smem_thr_copy_K ); if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } flash::cp_async_wait<0>(); @@ -887,7 +879,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } @@ -962,7 +954,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons smem_thr_copy_Q, smem_thr_copy_K ); if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } flash::cp_async_wait<0>(); diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 708aedd..b7408ec 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor const &S //////////////////////////////////////////////////////////////////////////////////////////////////// +template +__forceinline__ __device__ void apply_softcap(Tensor &tensor, const float softcap){ + #pragma unroll + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); + } +} + +template +__forceinline__ __device__ void calculate_dtanh(Tensor &src_tensor, Tensor &dst_tensor, const float softcap){ + #pragma unroll + for (int i = 0; i < size(src_tensor); ++i) { + dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace flash diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index f7130cd..cd9262b 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -261,9 +261,9 @@ def attention_ref( else: scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) if softcap > 0: - scores /= softcap + scores = scores / softcap scores = scores.tanh() - scores *= softcap + scores = scores * softcap if 
key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if window_size[0] >= 0 or window_size[1] >= 0: @@ -1122,7 +1122,7 @@ def test_flash_attn_output( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) - if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0: + if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)): assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() @@ -1382,7 +1382,7 @@ def test_flash_attn_varlen_output( print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") g = torch.randn_like(out) - if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0: + if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)): if kvpacked: ( dq_unpad,
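
The gradient the new backward path needs is compact: the forward softcap stores tanh(x * softcap) over the raw scores (apply_softcap), so each dS entry must later be scaled by the derivative of that mapping, (1 - tanh(x * softcap)^2) * softcap, which calculate_dtanh reads directly off the already-tanh'ed scores as (1 - scores^2) * softcap before the scaled_ds multiply in compute_dq_dk_dv_1colblock. What follows is a standalone host-side sketch, not part of the patch, that checks this identity against a central finite difference; std::tanh stands in for cutlass::fast_tanh, and the softcap value and sample points are arbitrary illustrative choices.

// Standalone host-side sketch, not part of the patch: verifies that the
// expression used in calculate_dtanh, (1 - tanh(x * softcap)^2) * softcap,
// is the derivative of the forward softcap apply_softcap(x) = tanh(x * softcap).
// std::tanh stands in for cutlass::fast_tanh; the softcap value and the
// sample points are illustrative only.
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
    const float softcap = 0.5f;   // illustrative value
    const float eps = 1e-3f;
    for (float x = -4.0f; x <= 4.0f; x += 0.37f) {
        // Forward: what apply_softcap leaves in the scores tensor.
        const float capped = std::tanh(x * softcap);
        // Backward factor: what calculate_dtanh computes from the post-tanh scores.
        const float dtanh_analytic = (1.0f - capped * capped) * softcap;
        // Central finite difference of the forward as a reference.
        const float dtanh_numeric =
            (std::tanh((x + eps) * softcap) - std::tanh((x - eps) * softcap)) / (2.0f * eps);
        assert(std::fabs(dtanh_analytic - dtanh_numeric) < 1e-3f);
    }
    std::printf("calculate_dtanh matches the finite-difference derivative\n");
    return 0;
}

The same chain rule is presumably why attention_ref moves from the in-place scores /= softcap and scores *= softcap to out-of-place assignments: with the "and softcap == 0.0" guards dropped from the backward checks, autograd now differentiates through the reference, and the in-place multiply would overwrite the tanh output that tanh's backward needs.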
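
The "no ternary operators allowed for constexpr" bullet in the commit message refers to how the dtanh tensor is materialized only when Is_softcap is set: a ?: expression cannot yield two different types, so the kernel uses an immediately-invoked lambda whose if constexpr branches return either a tensor or nullptr, and the return statement in the discarded branch does not take part in the lambda's return-type deduction. Below is a minimal standalone sketch of that pattern; std::array stands in for the CUTE tensor and all names are illustrative rather than taken from the patch.

// Standalone sketch, not part of the patch, of the "conditionally typed local"
// trick used for dtanh: the if constexpr branches of an immediately-invoked
// lambda return different types (real storage vs. a nullptr placeholder).
// std::array stands in for the CUTE tensor; all names are illustrative.
#include <array>
#include <cstdio>

template <bool Is_softcap>
float scaled_sum() {
    auto dtanh = [&] {
        if constexpr (Is_softcap) {
            // Per-element scale factors, playing the role of (1 - scores^2) * softcap.
            return std::array<float, 4>{0.9f, 0.8f, 0.7f, 0.6f};
        } else {
            return nullptr;  // placeholder; never read when Is_softcap is false
        }
    }();
    (void)dtanh;  // silences the unused warning in the Is_softcap == false instantiation

    float total = 0.f;
    for (int i = 0; i < 4; ++i) {
        float scaled_ds = 1.0f;      // stand-in for one dS entry
        if constexpr (Is_softcap) {
            scaled_ds *= dtanh[i];   // mirrors the scaled_ds *= dtanh(mi, ni) step
        }
        total += scaled_ds;
    }
    return total;
}

int main() {
    std::printf("softcap on: %.2f, softcap off: %.2f\n",
                scaled_sum<true>(), scaled_sum<false>());
    return 0;
}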
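
SOFTCAP_SWITCH in run_flash_bwd_seqk_parallel follows the same convention as the surrounding ALIBI_SWITCH/EVENK_SWITCH/LOCAL_SWITCH macros: a runtime condition (params.softcap > 0.0) selects one of two template instantiations so that Is_softcap is a compile-time constant inside the kernel. The macro below only illustrates that general shape; it is not the repository's definition (the *_SWITCH macros live in static_switch.h).

// Illustrative only: the general shape of a BOOL_SWITCH-style dispatch macro,
// not copied from static_switch.h. A runtime bool picks one of two template
// instantiations, so the flag is constexpr inside the dispatched lambda.
#include <cstdio>

#define EXAMPLE_BOOL_SWITCH(COND, CONST_NAME, ...)          \
    [&] {                                                   \
        if (COND) {                                         \
            constexpr static bool CONST_NAME = true;        \
            return __VA_ARGS__();                           \
        } else {                                            \
            constexpr static bool CONST_NAME = false;       \
            return __VA_ARGS__();                           \
        }                                                   \
    }()

template <bool Is_softcap>
void run_kernel(float softcap) {
    // Stand-in for launching a templated kernel with the compile-time flag.
    std::printf("Is_softcap=%d softcap=%f\n", int(Is_softcap), softcap);
}

int main(int argc, char **argv) {
    (void)argv;
    const float softcap = argc > 1 ? 30.0f : 0.0f;  // stands in for a runtime parameter
    EXAMPLE_BOOL_SWITCH(softcap > 0.0f, Is_softcap, [&] {
        run_kernel<Is_softcap>(softcap);  // Is_softcap is a compile-time constant here
    });
    return 0;
}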