From 37e32febba13fd0208c275f1f34cb63b1292cf7d Mon Sep 17 00:00:00 2001 From: Sophia Wisdom Date: Fri, 1 Sep 2023 16:43:58 -0700 Subject: [PATCH] Remove commented out code in bwd (#512) * Remove lots of comments * Remove unused traits --- .../src/flash_bwd_launch_template.h | 40 +------------------ csrc/flash_attn/src/kernel_traits.h | 6 --- 2 files changed, 1 insertion(+), 45 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index e4c36ea..81b2e4c 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -126,45 +126,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream, template void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { if (configure) return; - // dim3 grid(params.b, params.h); - // const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; - // dim3 grid_m(num_m_block, params.b, params.h); - - // if (params.h == params.h_k) { // No multi-query or grouped-query attention (MQA/GQA) - run_flash_bwd_seqk_parallel(params, stream, configure); - // } else { - // run_flash_bwd_seqq_parallel(params, stream, configure); - // } - - // // We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check - // // for cu_seqlens_q as well. 
- // const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0; - // const bool is_even_K = params.d == Kernel_traits::kHeadDim; - // constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize; - // BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { - // BOOL_SWITCH(is_even_M, IsEvenMConst, [&] { - // BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - // // auto kernel = &flash_bwd_dq_dk_dv_loop_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // if constexpr(smem_size_dq_dk_dv >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - // } - // kernel<<>>(params); - // C10_CUDA_KERNEL_LAUNCH_CHECK(); - // }); - // }); - // }); - // }); - - // auto kernel_dq = &flash_bwd_convert_dq_kernel; - // if constexpr(Kernel_traits::kSmemdQSize >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); - // } - // kernel_dq<<>>(params); - // C10_CUDA_KERNEL_LAUNCH_CHECK(); + run_flash_bwd_seqk_parallel(params, stream, configure); } // diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index d5e8a76..c7f2e4b 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -316,17 +316,11 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); static constexpr int kSmemPCount = size(SmemLayoutPdS{}); static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); - static constexpr int kSmemdPsumCount = kBlockM; static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); static constexpr int kSmemPSize 
= kSmemPCount * sizeof(Element); static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); - static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); - static constexpr int kSmemSize = kSmemQdOSize - + (!Is_V_in_regs - ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) - : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + kSmemPSize