diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 8fe200d..15842bc 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -1,5 +1,5 @@
 /******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
  ******************************************************************************/
 
 // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
@@ -204,7 +204,7 @@ void set_params_dgrad(Flash_bwd_params &params,
 
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
-        FWD_HEADDIM_SWITCH(params.d, [&] {
+        HEADDIM_SWITCH(params.d, [&] {
             if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                 run_mha_fwd_<elem_type, kHeadDim>(params, stream);
             } else {
@@ -695,25 +695,11 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
 
-void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     FP16_SWITCH(!params.is_bf16, [&] {
-        if (params.d <= 32) {
-            run_mha_bwd_<elem_type, 32>(params, stream, configure);
-        } else if (params.d <= 64) {
-            run_mha_bwd_<elem_type, 64>(params, stream, configure);
-        } else if (params.d <= 96) {
-            run_mha_bwd_<elem_type, 96>(params, stream, configure);
-        } else if (params.d <= 128) {
-            run_mha_bwd_<elem_type, 128>(params, stream, configure);
-        } else if (params.d <= 160) {
-            run_mha_bwd_<elem_type, 160>(params, stream, configure);
-        } else if (params.d <= 192) {
-            run_mha_bwd_<elem_type, 192>(params, stream, configure);
-        } else if (params.d <= 224) {
-            run_mha_bwd_<elem_type, 224>(params, stream, configure);
-        } else if (params.d <= 256) {
-            run_mha_bwd_<elem_type, 256>(params, stream, configure);
-        }
+        HEADDIM_SWITCH(params.d, [&] {
+            run_mha_bwd_<elem_type, kHeadDim>(params, stream);
+        });
     });
 }
 
@@ -898,7 +884,6 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
 
     auto launch = &run_mha_bwd;
-    // launch(params, stream, /*configure=*/true);
 
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -930,7 +915,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     }
 
     if (seqlen_q > 0) {
-        launch(params, stream, /*configure=*/false);
+        launch(params, stream);
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
         dk_expanded.zero_();
@@ -1154,7 +1139,6 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
 
     auto launch = &run_mha_bwd;
-    // launch(params, stream, /*configure=*/true);
 
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -1186,7 +1170,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     }
 
     if (max_seqlen_q > 0) {
-        launch(params, stream, /*configure=*/false);
+        launch(params, stream);
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
         dk_expanded.zero_();
diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h
index 4a33f3d..d0852b6 100644
--- a/csrc/flash_attn/src/flash.h
+++ b/csrc/flash_attn/src/flash.h
@@ -182,4 +182,4 @@ struct Flash_bwd_params : public Flash_fwd_params {
 template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
 template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
+template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
index 78f4793..fe58c3f 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
 }
diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
index 641cac0..95fcd0f 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim128<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
 }
diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
index ad763a6..56a64dd 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
 }
diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
index 23d8145..15ac214 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim160<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
 }
diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
index 82dafe7..4df57d6 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
 }
diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
index 55dcab4..824e82c 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim192<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
 }
diff --git a/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu
index e987c00..b2b58e2 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream);
 }
diff --git a/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu
index 37430ba..e65cdae 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim224<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim224<cutlass::half_t>(params, stream);
 }
diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
index 6d4b10e..d044d39 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
 }
diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
index 0a21442..b23e1bc 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim256<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
 }
diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
index a7a1506..85c0846 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
 }
diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
index b2281ee..8a07da7 100644
--- a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
+++ b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim32(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu index 464bf9b..94c0235 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim64(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu index f2439a2..3307b5d 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim64(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu index 1234ff4..bd17d46 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim96(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu index c21f906..311a835 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim96(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 5a29f20..c5eef04 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -43,7 +43,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) { } template -void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid_m(num_m_block, params.b, params.h); const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; @@ -99,13 +99,12 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, } template -void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - if (configure) return; - 
+void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
 }
 
 template<typename T>
-void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     int device;
     cudaGetDevice(&device);
@@ -118,18 +117,18 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) {  // 104 KB
             if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
-                run_flash_bwd, Is_dropout>(params, stream, configure);
+                run_flash_bwd, Is_dropout>(params, stream);
             } else {
-                run_flash_bwd, Is_dropout>(params, stream, configure);
+                run_flash_bwd, Is_dropout>(params, stream);
             }
         } else {  // 96 KB
-            run_flash_bwd, Is_dropout>(params, stream, configure);
+            run_flash_bwd, Is_dropout>(params, stream);
         }
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     int device;
     cudaGetDevice(&device);
@@ -142,39 +141,39 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // Changing AtomLayoutMdQ from 2 to 4 takes the same time
-        // run_flash_bwd>(params, stream, configure);
-        // run_flash_bwd>(params, stream, configure);
-        // run_flash_bwd>(params, stream, configure);
-        // run_flash_bwd, Is_dropout>(params, stream, configure);
+        // run_flash_bwd>(params, stream);
+        // run_flash_bwd>(params, stream);
+        // run_flash_bwd>(params, stream);
+        // run_flash_bwd, Is_dropout>(params, stream);
         // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
         if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd, Is_dropout>(params, stream, configure);
+            run_flash_bwd, Is_dropout>(params, stream);
             // This has a lot of register spilling
-            // run_flash_bwd, Is_dropout>(params, stream, configure);
+            // run_flash_bwd, Is_dropout>(params, stream);
         } else {
             // if (params.h == params.h_k) {
-                // run_flash_bwd, Is_dropout>(params, stream, configure);
-                run_flash_bwd, Is_dropout>(params, stream, configure);
-                // run_flash_bwd, Is_dropout>(params, stream, configure);
-                // run_flash_bwd, Is_dropout>(params, stream, configure);
+                // run_flash_bwd, Is_dropout>(params, stream);
+                run_flash_bwd, Is_dropout>(params, stream);
+                // run_flash_bwd, Is_dropout>(params, stream);
+                // run_flash_bwd, Is_dropout>(params, stream);
             // } else {
             // }
         }
     });
-    // run_flash_bwd>(params, stream, configure);
-    // run_flash_bwd>(params, stream, configure);
-    // run_flash_bwd>(params, stream, configure);
-    // run_flash_bwd>(params, stream, configure);
+    // run_flash_bwd>(params, stream);
+    // run_flash_bwd>(params, stream);
+    // run_flash_bwd>(params, stream);
+    // run_flash_bwd>(params, stream);
     // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
-    // run_flash_bwd>(params, stream, configure);
-    // run_flash_bwd>(params, stream, configure);
-    // run_flash_bwd>(params, stream, configure);
+    // run_flash_bwd>(params, stream);
+    // run_flash_bwd>(params, stream);
+    // run_flash_bwd>(params, stream);
 
-    // run_flash_bwd>(params, stream, configure);
+    // run_flash_bwd>(params, stream);
 }
 
 template<typename T>
-void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     int device;
     cudaGetDevice(&device);
@@ -188,19 +187,19 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
             if constexpr(!Is_dropout) {  // 92KB
-                run_flash_bwd, Is_dropout>(params, stream, configure);
+                run_flash_bwd, Is_dropout>(params, stream);
             } else {  // 116 KB
                 // This is faster for dropout since we don't have many registers to spare
-                run_flash_bwd, Is_dropout>(params, stream, configure);
+                run_flash_bwd, Is_dropout>(params, stream);
             }
         } else {
-            run_flash_bwd, Is_dropout>(params, stream, configure);
+            run_flash_bwd, Is_dropout>(params, stream);
         }
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     int device;
     cudaGetDevice(&device);
@@ -212,29 +211,29 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // run_flash_bwd>(params, stream, configure);
+        // run_flash_bwd>(params, stream);
         // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
        // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
-        // run_flash_bwd>(params, stream, configure);
+        // run_flash_bwd>(params, stream);
         if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd, Is_dropout>(params, stream, configure);
-            // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream, configure);
-            // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream, configure);
-            // run_flash_bwd, Is_dropout>(params, stream, configure);
-            // run_flash_bwd, Is_dropout>(params, stream, configure);
-            // run_flash_bwd, Is_dropout>(params, stream, configure);
+            run_flash_bwd, Is_dropout>(params, stream);
+            // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream);
+            // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream);
+            // run_flash_bwd, Is_dropout>(params, stream);
+            // run_flash_bwd, Is_dropout>(params, stream);
+            // run_flash_bwd, Is_dropout>(params, stream);
         } else {
-            // run_flash_bwd, Is_dropout>(params, stream, configure);
-            run_flash_bwd, Is_dropout>(params, stream, configure);
+            // run_flash_bwd, Is_dropout>(params, stream);
+            run_flash_bwd, Is_dropout>(params, stream);
         }
-        // run_flash_bwd>(params, stream, configure);
+        // run_flash_bwd>(params, stream);
 
-        // run_flash_bwd>(params, stream, configure);
+        // run_flash_bwd>(params, stream);
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     int device;
     cudaGetDevice(&device);
@@ -246,15 +245,15 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
-            run_flash_bwd, Is_dropout>(params, stream, configure);
+            run_flash_bwd, Is_dropout>(params, stream);
         } else {
-            run_flash_bwd, Is_dropout>(params, stream, configure);
+            run_flash_bwd, Is_dropout>(params, stream);
         }
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     int device;
     cudaGetDevice(&device);
@@ -266,23 +265,23 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 136 * 1024) {
-            run_flash_bwd, Is_dropout>(params, stream, configure);
+            run_flash_bwd, Is_dropout>(params, stream);
         } else {
-            run_flash_bwd, Is_dropout>(params, stream, configure);
+            run_flash_bwd, Is_dropout>(params, stream);
         }
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 224;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_bwd, Is_dropout>(params, stream, configure);
+        run_flash_bwd, Is_dropout>(params, stream);
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
@@ -294,9 +293,9 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 176 * 1024) {  // H100
-            run_flash_bwd, Is_dropout>(params, stream, configure);
+            run_flash_bwd, Is_dropout>(params, stream);
         } else {  // A100, we don't do double buffering to save smem
-            run_flash_bwd, Is_dropout>(params, stream, configure);
+            run_flash_bwd, Is_dropout>(params, stream);
         }
     });
 }
diff --git a/csrc/flash_attn/src/generate_kernels.py b/csrc/flash_attn/src/generate_kernels.py
index 62e6d06..0f71002 100644
--- a/csrc/flash_attn/src/generate_kernels.py
+++ b/csrc/flash_attn/src/generate_kernels.py
@@ -32,8 +32,8 @@ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params
 
 KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {{
-    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure);
+void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
+    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
 }}
 """
diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h
index 4aa8474..ab55dd9 100644
--- a/csrc/flash_attn/src/static_switch.h
+++ b/csrc/flash_attn/src/static_switch.h
@@ -36,7 +36,7 @@
     }                                      \
   }()
 
-#define FWD_HEADDIM_SWITCH(HEADDIM, ...)   \
+#define HEADDIM_SWITCH(HEADDIM, ...)       \
   [&] {                                    \
     if (HEADDIM <= 32) {                   \
       constexpr static int kHeadDim = 32;  \
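
Note (illustration, not part of the patch): the dispatch pattern the new run_mha_bwd relies on can be hard to read from the diff alone, so here is a minimal, self-contained sketch of how FP16_SWITCH and HEADDIM_SWITCH compose. The Params struct, the stand-in element types, and the two-branch HEADDIM_SWITCH below are simplified placeholders; the real macros in csrc/flash_attn/src/static_switch.h select cutlass::half_t / cutlass::bfloat16_t and chain through head dimensions 32 to 256.

// Sketch only: simplified stand-ins, not the real static_switch.h definitions.
#include <cstdio>

struct Params { bool is_bf16; int d; };
struct bf16_t {}; struct half_t {};  // placeholders for cutlass::bfloat16_t / cutlass::half_t

#define FP16_SWITCH(COND, ...)                                     \
  [&] {                                                            \
    if (COND) { using elem_type = half_t; return __VA_ARGS__(); }  \
    else      { using elem_type = bf16_t; return __VA_ARGS__(); }  \
  }()

// Toy two-branch version; the real HEADDIM_SWITCH covers 32, 64, 96, ..., 256.
#define HEADDIM_SWITCH(HEADDIM, ...)                                                    \
  [&] {                                                                                 \
    if ((HEADDIM) <= 32) { constexpr static int kHeadDim = 32; return __VA_ARGS__(); }  \
    else                 { constexpr static int kHeadDim = 64; return __VA_ARGS__(); }  \
  }()

// Stand-in for the per-(dtype, head-dim) launcher; note there is no `configure` flag anymore.
template <typename T, int Headdim>
void run_mha_bwd_(Params &params) {
    (void)params;
    std::printf("bwd launch: head_dim=%d\n", Headdim);
}

// Same shape as the new run_mha_bwd in flash_api.cpp: the nested macro dispatch
// replaces the hand-written if/else-if chain over params.d.
void run_mha_bwd(Params &params) {
    FP16_SWITCH(!params.is_bf16, [&] {
        HEADDIM_SWITCH(params.d, [&] {
            run_mha_bwd_<elem_type, kHeadDim>(params);
        });
    });
}

int main() {
    Params p{/*is_bf16=*/false, /*d=*/64};
    run_mha_bwd(p);  // prints "bwd launch: head_dim=64"
    return 0;
}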