// Copyright (c) 2023, Tri Dao.

#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_kernel.h"

template<typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) {
    flash::compute_dot_do_o<Kernel_traits>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) {
    flash::clear_dKVaccum<Kernel_traits>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
    flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) {
    flash::compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) {
    flash::convert_dQ<Kernel_traits>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) {
    flash::convert_dKV<Kernel_traits>(params);
}

template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    dim3 grid_n(num_n_block, params.b, params.h);

    // Precompute the row-wise dot(dO, O) term used by the backward pass.
    flash_bwd_dot_do_o_kernel<Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
    // a multiple of kBlockN, we'll need to apply mask in the loop.
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr
        && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
    BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
                if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                }
                kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
            });
        });
    });

    // Convert the fp32 dQ accumulator to the output dtype.
    auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    if constexpr(Kernel_traits::kSmemdQSize >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
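// The nested BOOL_SWITCH calls above turn runtime flags (is_causal, is_even_MN, is_even_K) into
// compile-time booleans, so a separately-specialized kernel is selected for each combination.
// As a rough sketch (assuming a conventional static_switch.h-style definition, not necessarily
// verbatim), the macro expands to an immediately-invoked lambda along these lines:
//
//   #define BOOL_SWITCH(COND, CONST_NAME, ...)        \
//     [&] {                                           \
//       if (COND) {                                   \
//         constexpr static bool CONST_NAME = true;    \
//         return __VA_ARGS__();                       \
//       } else {                                      \
//         constexpr static bool CONST_NAME = false;   \
//         return __VA_ARGS__();                       \
//       }                                             \
//     }()
//
// Each branch instantiates the lambda body with CONST_NAME usable as a template argument.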
template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    dim3 grid_n(num_n_block, params.b, params.h_k);
    flash_bwd_clear_dkvaccum_kernel<Kernel_traits><<<grid_n, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
    // for cu_seqlens_k as well.
    const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr
        && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
    BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
        BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
                if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                }
                kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
            });
        });
    });

    // Convert the fp32 dK/dV accumulators to the output dtype.
    auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
    if constexpr(Kernel_traits::kSmemKVSize >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
    }
    kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemKVSize, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    if (configure) return;
    // if (params.h == params.h_k) {  // No multi-query or grouped-query attention (MQA/GQA)
        run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
    // } else {
    //     run_flash_bwd_seqq_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
    // }
}

template<typename T>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr int Headdim = 32;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // Flash_bwd_kernel_traits<kHeadDim, kBlockM, kBlockN, kNWarps, AtomLayoutMSdP,
        //                         AtomLayoutNdKV, AtomLayoutMdQ, Is_V_in_regs, No_double_buffer, T>
        if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) {  // 104 KB
            if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
            } else {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            }
        } else {  // 96 KB
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
        }
    });
}
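// A worked example of the shared-memory check above, assuming the leading factor of 2 is the
// element size in bytes (fp16/bf16) and 128 is the kBlockM/kBlockN tile size:
//   2 * ((3 * 128 + 2 * 128) * 32 + 2 * 128 * 128)
//     = 2 * (640 * 32 + 32768)
//     = 2 * 53248
//     = 106496 bytes  (exactly 104 KB)
// i.e. plausibly five 128 x Headdim tiles plus two 128 x 128 score tiles, which is why the branch
// is annotated "104 KB"; devices without that much opt-in shared memory per block fall through to
// the smaller "96 KB" configuration.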
template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr int Headdim = 64;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // Changing AtomLayoutMdQ from 2 to 4 takes the same time.
        // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
        } else {
            // if (params.h == params.h_k) {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            // } else {
            //     run_flash_bwd_seqq_parallel(params, stream, configure);
            // }
        }
    });
    // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times.
}

template<typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr int Headdim = 96;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // if (params.h == params.h_k) {
            if (max_smem_per_block >= 116 * 1024) {
                if constexpr(!Is_dropout) {  // 92KB
                    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
                } else {  // 116 KB
                    // This is faster for dropout since we don't have many registers to spare
                    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
                }
            } else {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
            }
        // } else {
        //     run_flash_bwd_seqq_parallel(params, stream, configure);
        // }
    });
}
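// For reference (approximate values from vendor documentation): the maximum opt-in shared memory
// per block queried above is about 96 KB on sm70 (V100), 163 KB on sm80 (A100), 99 KB on
// sm86/sm89 (consumer Ampere/Ada), and 227 KB on sm90 (H100). The 116/136/144/176 KB thresholds
// used in these dispatch functions therefore effectively separate data-center GPUs from consumer
// GPUs with smaller shared-memory capacity.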
template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr int Headdim = 128;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // if (params.h == params.h_k) {
            // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
            if (max_smem_per_block >= 144 * 1024) {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
            } else {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
            }
        // } else {
        //     run_flash_bwd_seqq_parallel(params, stream, configure);
        // }
    });
}

template<typename T>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr int Headdim = 160;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
        }
    });
}

template<typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr int Headdim = 192;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 136 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
        }
    });
}

template<typename T>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr int Headdim = 224;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
    });
}

template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 176 * 1024) {  // H100
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
        } else {  // A100, we don't do double buffering to save smem
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
        }
    });
}
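// Callers are expected to pick one of the run_mha_bwd_hdim* instantiations based on the head
// dimension in params.d (the kernels themselves handle d smaller than kHeadDim via the is_even_K
// switch). The actual dispatch lives outside this header; a minimal sketch of such a dispatcher
// (hypothetical helper name, rounding d up to the next supported head dimension) might look like:
//
//   template<typename T>
//   void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
//       if (params.d <= 32)       { run_mha_bwd_hdim32<T>(params, stream, configure); }
//       else if (params.d <= 64)  { run_mha_bwd_hdim64<T>(params, stream, configure); }
//       else if (params.d <= 96)  { run_mha_bwd_hdim96<T>(params, stream, configure); }
//       else if (params.d <= 128) { run_mha_bwd_hdim128<T>(params, stream, configure); }
//       else if (params.d <= 160) { run_mha_bwd_hdim160<T>(params, stream, configure); }
//       else if (params.d <= 192) { run_mha_bwd_hdim192<T>(params, stream, configure); }
//       else if (params.d <= 224) { run_mha_bwd_hdim224<T>(params, stream, configure); }
//       else if (params.d <= 256) { run_mha_bwd_hdim256<T>(params, stream, configure); }
//   }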