Remove commented-out code in bwd (#512)

* Remove lots of comments

* Remove unused traits
Sophia Wisdom 2023-09-01 16:43:58 -07:00 committed by GitHub
parent dd8a754915
commit 37e32febba
2 changed files with 1 addition and 45 deletions
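The net effect of the first hunk below is that run_flash_bwd collapses to a thin forwarder onto the seqlen_k-parallel launcher. A minimal sketch of the resulting function, reconstructed from the kept lines of the diff (Flash_bwd_params, Kernel_traits, and run_flash_bwd_seqk_parallel are assumed to come from the surrounding launch template):

// Resulting dispatcher after the cleanup: nothing is left but the forward
// to the seqlen_k-parallel backward launcher.
template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    if (configure) return;  // configuration-only pass: skip the launch
    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
}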


@@ -126,45 +126,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
if (configure) return;
// dim3 grid(params.b, params.h);
// const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
// dim3 grid_m(num_m_block, params.b, params.h);
// if (params.h == params.h_k) { // No multi-query or grouped-query attention (MQA/GQA)
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
// } else {
// run_flash_bwd_seqq_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
// }
// // We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check
// // for cu_seqlens_q as well.
// const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0;
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
// constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
// BOOL_SWITCH(is_even_M, IsEvenMConst, [&] {
// BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
// // auto kernel = &flash_bwd_dq_dk_dv_loop_kernel<Kernel_traits, Is_dropout, IsCausalConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
// if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
// C10_CUDA_CHECK(cudaFuncSetAttribute(
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
// }
// kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
// C10_CUDA_KERNEL_LAUNCH_CHECK();
// });
// });
// });
// });
// auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
// if constexpr(Kernel_traits::kSmemdQSize >= 48 * 1024) {
// C10_CUDA_CHECK(cudaFuncSetAttribute(
// kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
// }
// kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
// C10_CUDA_KERNEL_LAUNCH_CHECK();
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
}
//
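The deleted launch path lifted runtime flags (dropout, causal masking, even-M/even-K) into compile-time template parameters through nested BOOL_SWITCH calls before launching the kernel. A minimal sketch of that dispatch idiom, using a simplified macro and a hypothetical specialized kernel (an illustration of the pattern, not necessarily the repository's exact static_switch.h):

#include <cstdio>

// Illustrative BOOL_SWITCH: turns a runtime bool into a compile-time constant
// by instantiating both branches and selecting one of them at runtime.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

// Hypothetical stand-in for a kernel whose code is specialized by the flags.
template<bool IsCausalConst, bool IsEvenKConst>
void launch_specialized_kernel() {
    std::printf("causal=%d even_K=%d\n", IsCausalConst, IsEvenKConst);
}

// Nested switches mirroring the structure of the removed launch code: each
// runtime flag selects one compile-time instantiation of the kernel.
void dispatch(bool is_causal, bool is_even_K) {
    BOOL_SWITCH(is_causal, IsCausalConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            launch_specialized_kernel<IsCausalConst, IsEvenKConst>();
        });
    });
}

Each additional flag doubles the number of template instantiations, which is why the removed code nested four such switches around a single kernel launch.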


@@ -316,17 +316,11 @@ struct Flash_bwd_kernel_traits : public Base {
static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
static constexpr int kSmemPCount = size(SmemLayoutPdS{});
static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
static constexpr int kSmemdPsumCount = kBlockM;
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
static constexpr int kSmemSize1colblock = kSmemQdOSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + kSmemPSize