Remove configure in bwd kernel launch
This commit is contained in:
parent
af01244ddd
commit
ea8a25ca38
@ -1,5 +1,5 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
|
||||
@ -204,7 +204,7 @@ void set_params_dgrad(Flash_bwd_params ¶ms,
|
||||
|
||||
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
FWD_HEADDIM_SWITCH(params.d, [&] {
|
||||
HEADDIM_SWITCH(params.d, [&] {
|
||||
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
|
||||
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
|
||||
} else {
|
||||
@ -695,25 +695,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
||||
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
|
||||
}
|
||||
|
||||
void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
if (params.d <= 32) {
|
||||
run_mha_bwd_<elem_type, 32>(params, stream, configure);
|
||||
} else if (params.d <= 64) {
|
||||
run_mha_bwd_<elem_type, 64>(params, stream, configure);
|
||||
} else if (params.d <= 96) {
|
||||
run_mha_bwd_<elem_type, 96>(params, stream, configure);
|
||||
} else if (params.d <= 128) {
|
||||
run_mha_bwd_<elem_type, 128>(params, stream, configure);
|
||||
} else if (params.d <= 160) {
|
||||
run_mha_bwd_<elem_type, 160>(params, stream, configure);
|
||||
} else if (params.d <= 192) {
|
||||
run_mha_bwd_<elem_type, 192>(params, stream, configure);
|
||||
} else if (params.d <= 224) {
|
||||
run_mha_bwd_<elem_type, 224>(params, stream, configure);
|
||||
} else if (params.d <= 256) {
|
||||
run_mha_bwd_<elem_type, 256>(params, stream, configure);
|
||||
}
|
||||
HEADDIM_SWITCH(params.d, [&] {
|
||||
run_mha_bwd_<elem_type, kHeadDim>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@ -898,7 +884,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
|
||||
|
||||
auto launch = &run_mha_bwd;
|
||||
// launch(params, stream, /*configure=*/true);
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
@ -930,7 +915,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
}
|
||||
|
||||
if (seqlen_q > 0) {
|
||||
launch(params, stream, /*configure=*/false);
|
||||
launch(params, stream);
|
||||
} else {
|
||||
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
|
||||
dk_expanded.zero_();
|
||||
@ -1154,7 +1139,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
|
||||
|
||||
auto launch = &run_mha_bwd;
|
||||
// launch(params, stream, /*configure=*/true);
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
@ -1186,7 +1170,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
}
|
||||
|
||||
if (max_seqlen_q > 0) {
|
||||
launch(params, stream, /*configure=*/false);
|
||||
launch(params, stream);
|
||||
} else {
|
||||
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
|
||||
dk_expanded.zero_();
|
||||
|
||||
@ -182,4 +182,4 @@ struct Flash_bwd_params : public Flash_fwd_params {
|
||||
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure);
|
||||
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream);
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim128<cutlass::half_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim160<cutlass::half_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim192<cutlass::half_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim224<cutlass::half_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim224<cutlass::half_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim256<cutlass::half_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim32<cutlass::half_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim64<cutlass::half_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim96<cutlass::half_t>(params, stream, configure);
|
||||
void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim96<cutlass::half_t>(params, stream);
|
||||
}
|
||||
|
||||
@ -43,7 +43,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid_m(num_m_block, params.b, params.h);
|
||||
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
@ -99,13 +99,12 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
if (configure) return;
|
||||
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
|
||||
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 32;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -118,18 +117,18 @@ void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const boo
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
|
||||
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
} else { // 96 KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 64;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -142,39 +141,39 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const boo
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// This is slightly faster. We want to split M more so we need fewer registers to store LSE.
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// This has a lot of register spilling
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
// if (params.h == params.h_k) {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
// } else {
|
||||
// }
|
||||
}
|
||||
});
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
|
||||
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
|
||||
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 96;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -188,19 +187,19 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const boo
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
if constexpr(!Is_dropout) { // 92KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else { // 116 KB
|
||||
// This is faster for dropout since we don't have many registers to spare
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 128;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -212,29 +211,29 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
|
||||
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
|
||||
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
|
||||
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 160;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -246,15 +245,15 @@ void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
|
||||
}
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 192;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -266,23 +265,23 @@ void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
|
||||
}
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 136 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 224;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 256;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -294,9 +293,9 @@ void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
|
||||
}
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 176 * 1024) { // H100
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
} else { // A100, we don't do double buffering to save smem
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@ -32,8 +32,8 @@ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params
|
||||
KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {{
|
||||
run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure);
|
||||
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{
|
||||
run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@
|
||||
} \
|
||||
}()
|
||||
|
||||
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
|
||||
#define HEADDIM_SWITCH(HEADDIM, ...) \
|
||||
[&] { \
|
||||
if (HEADDIM <= 32) { \
|
||||
constexpr static int kHeadDim = 32; \
|
||||
|
||||
Loading…
Reference in New Issue
Block a user