Remove commented out code in bwd (#512)
* Remove lots of comments
* Remove unused traits
This commit is contained in:
parent dd8a754915
commit 37e32febba
@@ -126,45 +126,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
 template<typename Kernel_traits, bool Is_dropout>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
     if (configure) return;
-    // dim3 grid(params.b, params.h);
-    // const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
-    // dim3 grid_m(num_m_block, params.b, params.h);
-
-    // if (params.h == params.h_k) { // No multi-query or grouped-query attention (MQA/GQA)
-        run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
-    // } else {
-    //     run_flash_bwd_seqq_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
-    // }
-
-    // // We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check
-    // // for cu_seqlens_q as well.
-    // const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0;
-    // const bool is_even_K = params.d == Kernel_traits::kHeadDim;
-    // constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize;
-    // BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-    //     BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
-    //         BOOL_SWITCH(is_even_M, IsEvenMConst, [&] {
-    //             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-    //                 // auto kernel = &flash_bwd_dq_dk_dv_loop_kernel<Kernel_traits, Is_dropout, IsCausalConst>;
-    //                 auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
-    //                 if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
-    //                     C10_CUDA_CHECK(cudaFuncSetAttribute(
-    //                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-    //                 }
-    //                 kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-    //                 C10_CUDA_KERNEL_LAUNCH_CHECK();
-    //             });
-    //         });
-    //     });
-    // });
-
-    // auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
-    // if constexpr(Kernel_traits::kSmemdQSize >= 48 * 1024) {
-    //     C10_CUDA_CHECK(cudaFuncSetAttribute(
-    //         kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
-    // }
-    // kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
-    // C10_CUDA_KERNEL_LAUNCH_CHECK();
+    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
 }
 
 //
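For reference, the deleted launcher combined two standard CUDA idioms: BOOL_SWITCH lifts runtime flags (dropout, causal, even-M/K) into compile-time template parameters, and cudaFuncSetAttribute opts a kernel in to more than 48 KB of dynamic shared memory before launch. Below is a minimal self-contained sketch of that pattern, not the repo's code: the kernel body, the simplified BOOL_SWITCH, and the 64 KB size are illustrative assumptions.

#include <cuda_runtime.h>
#include <cstdio>

// Simplified stand-in for the repo's BOOL_SWITCH macro: dispatch a
// runtime bool to a lambda templated on a compile-time constant.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                       \
        if (COND) {                             \
            constexpr bool CONST_NAME = true;   \
            return __VA_ARGS__();               \
        } else {                                \
            constexpr bool CONST_NAME = false;  \
            return __VA_ARGS__();               \
        }                                       \
    }()

// Hypothetical kernel standing in for the flash_bwd kernels; it only
// touches dynamic shared memory so the launch is observable.
template <bool Is_dropout>
__global__ void bwd_kernel_stub(float *out) {
    extern __shared__ float smem[];
    smem[threadIdx.x] = Is_dropout ? 1.f : 0.f;
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];
}

void launch(float *out, float p_dropout, cudaStream_t stream) {
    constexpr int kSmemSize = 64 * 1024;  // assumed footprint, above the 48 KB default
    BOOL_SWITCH(p_dropout < 1.f, Is_dropout, [&] {
        auto kernel = &bwd_kernel_stub<Is_dropout>;
        // Kernels requesting > 48 KB of dynamic smem must opt in first;
        // this mirrors the deleted C10_CUDA_CHECK(cudaFuncSetAttribute(...)).
        if constexpr (kSmemSize >= 48 * 1024) {
            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
        }
        kernel<<<dim3(1), 128, kSmemSize, stream>>>(out);
    });
}

int main() {
    float *out;
    cudaMalloc(&out, 128 * sizeof(float));
    launch(out, 0.9f, nullptr);  // dropout enabled at runtime, dispatched at compile time
    printf("launch: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));
    cudaFree(out);
    return 0;
}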
@@ -316,17 +316,11 @@ struct Flash_bwd_kernel_traits : public Base {
     static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
     static constexpr int kSmemPCount = size(SmemLayoutPdS{});
     static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
-    static constexpr int kSmemdPsumCount = kBlockM;
     static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
     static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
     static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
     static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
     static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
-    static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
-    static constexpr int kSmemSize = kSmemQdOSize
-        + (!Is_V_in_regs
-           ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
-           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
     static constexpr int kSmemSize1colblock = kSmemQdOSize
         + (!Is_V_in_regs
            ? kSmemKVSize + kSmemdSSize + kSmemPSize
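The surviving traits are plain byte accounting: each *Count is the element count of a shared-memory tile layout, each *Size multiplies by the element width, and kSmemSize1colblock totals the tiles that are simultaneously live for one column block; the kSmemKVSize / 2 term reads as reusing half of the K/V buffer for dS/P when Is_V_in_regs keeps V in registers. Here is a self-contained sketch of the same arithmetic, where the tile shape (kBlockM, kBlockN, kHeadDim) and the fp16 element width are assumptions for illustration, not the repo's actual layouts.

#include <algorithm>
#include <cstdio>

// Hypothetical tile shape, just to make the arithmetic concrete.
constexpr int kBlockM = 128, kBlockN = 128, kHeadDim = 64;
constexpr int kElemBytes = 2;  // e.g. fp16

// Element counts standing in for size(SmemLayoutXxx{}).
constexpr int kSmemQdOCount = 2 * kBlockM * kHeadDim;  // Q and dO tiles
constexpr int kSmemKVCount  = 2 * kBlockN * kHeadDim;  // K and V tiles
constexpr int kSmemdSCount  = kBlockM * kBlockN;
constexpr int kSmemPCount   = kBlockM * kBlockN;

constexpr int kSmemQdOSize = kSmemQdOCount * kElemBytes;
constexpr int kSmemKVSize  = kSmemKVCount * kElemBytes;
constexpr int kSmemdSSize  = kSmemdSCount * kElemBytes;
constexpr int kSmemPSize   = kSmemPCount * kElemBytes;

// Same shape as kSmemSize1colblock in the diff: with V held in
// registers, V's half of the K/V buffer can be reclaimed for dS/P.
template <bool Is_V_in_regs>
constexpr int smem_size_1colblock() {
    return kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + kSmemPSize
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
}

int main() {
    printf("V in smem: %d bytes\n", smem_size_1colblock<false>());  // 131072
    printf("V in regs: %d bytes\n", smem_size_1colblock<true>());   // 114688
    return 0;
}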