diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 9436b22..84cb71f 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -112,6 +112,9 @@ void set_params_fprop(Flash_fwd_params &params,
     params.rp_dropout = 1.f / params.p_dropout;
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
     TORCH_CHECK(p_dropout < 1.f);
+    #ifdef FLASHATTENTION_DISABLE_DROPOUT
+        TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
+    #endif
 
     // Causal is the special case where window_size_right == 0 and window_size_left < 0.
     // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
@@ -122,7 +125,16 @@ void set_params_fprop(Flash_fwd_params &params,
     params.window_size_left = window_size_left;
     params.window_size_right = window_size_right;
 
+    #ifdef FLASHATTENTION_DISABLE_LOCAL
+        TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
+            "This flash attention build does not support local attention.");
+    #endif
+
     params.is_seqlens_k_cumulative = true;
+
+    #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
+        TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
+    #endif
 }
 
 void set_params_dgrad(Flash_bwd_params &params,
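
Reviewer note: d_rounded here is the head dimension rounded up to the next multiple of 32 (which is what the error message is pointing at), so d == d_rounded holds exactly when the head dimension is already a multiple of 32. For example, headdim 96 passes in a FLASHATTENTION_DISABLE_UNEVEN_K build, while headdim 80 rounds up to 96 and is rejected.
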
@@ -282,6 +294,25 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     }
 }
 
+void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
+#ifdef FLASHATTENTION_DISABLE_ALIBI
+    TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
+    params.alibi_slopes_ptr = nullptr;
+#else
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
+        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+    } else {
+        params.alibi_slopes_ptr = nullptr;
+    }
+#endif
+}
+
 std::vector<at::Tensor>
 mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,   // batch_size x seqlen_k x num_heads_k x head_size
@@ -435,17 +466,7 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
-    if (alibi_slopes_.has_value()) {
-        auto alibi_slopes = alibi_slopes_.value();
-        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
-        CHECK_DEVICE(alibi_slopes);
-        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
-    } else {
-        params.alibi_slopes_ptr = nullptr;
-    }
+    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
     if (seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -657,17 +678,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
         params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
-    if (alibi_slopes_.has_value()) {
-        auto alibi_slopes = alibi_slopes_.value();
-        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
-        CHECK_DEVICE(alibi_slopes);
-        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
-    } else {
-        params.alibi_slopes_ptr = nullptr;
-    }
+    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
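
Reviewer note: consolidating the five copies of this validation into set_params_alibi also gives the disable flag a single place to live. A minimal caller-side illustration of the two paths through the helper (params, batch_size and num_heads are assumed to be in scope; the tensor below is made up for the example and is not part of the patch):

    c10::optional<at::Tensor> no_slopes;  // ALiBi not used
    set_params_alibi(params, no_slopes, batch_size, num_heads);  // ptr stays nullptr in every build

    auto slopes = torch::rand({num_heads},
                              torch::dtype(torch::kFloat32).device(torch::kCUDA));
    c10::optional<at::Tensor> some_slopes = slopes;
    // Default build: validates dtype/device/last-dim stride/shape, then sets
    // alibi_slopes_ptr and alibi_slopes_batch_stride (0 for the 1-D shape).
    // FLASHATTENTION_DISABLE_ALIBI build: fails the TORCH_CHECK instead.
    set_params_alibi(params, some_slopes, batch_size, num_heads);
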
@@ -724,6 +735,9 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_size_og
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {
 
+    #ifdef FLASHATTENTION_DISABLE_BACKWARD
+        TORCH_CHECK(false, "This flash attention build does not support backward.");
+    #endif
     if (is_causal) { window_size_right = 0; }
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -903,17 +917,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_size_og
         params.rng_state[1] = std::get<1>(seeds);
     }
 
-    if (alibi_slopes_.has_value()) {
-        auto alibi_slopes = alibi_slopes_.value();
-        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
-        CHECK_DEVICE(alibi_slopes);
-        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
-    } else {
-        params.alibi_slopes_ptr = nullptr;
-    }
+    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
     if (seqlen_q > 0) {
         launch(params, stream);
@@ -963,6 +967,10 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {
 
+    #ifdef FLASHATTENTION_DISABLE_BACKWARD
+        TORCH_CHECK(false, "This flash attention build does not support backward.");
+    #endif
+
     if (is_causal) { window_size_right = 0; }
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -1158,17 +1166,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         params.rng_state[1] = std::get<1>(seeds);
     }
 
-    if (alibi_slopes_.has_value()) {
-        auto alibi_slopes = alibi_slopes_.value();
-        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
-        CHECK_DEVICE(alibi_slopes);
-        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
-    } else {
-        params.alibi_slopes_ptr = nullptr;
-    }
+    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
     if (max_seqlen_q > 0) {
         launch(params, stream);
@@ -1435,17 +1433,8 @@ mha_fwd_kvcache(at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
     }
     params.page_block_size = page_block_size;
 
-    if (alibi_slopes_.has_value()) {
-        auto alibi_slopes = alibi_slopes_.value();
-        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
-        CHECK_DEVICE(alibi_slopes);
-        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
-    } else {
-        params.alibi_slopes_ptr = nullptr;
-    }
+
+    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
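
Reviewer note: all five FLASHATTENTION_DISABLE_* guards introduced in this file (DROPOUT, LOCAL, UNEVEN_K, ALIBI, BACKWARD) are plain preprocessor defines, so a trimmed build would be produced simply by defining them at compile time (e.g. passing -DFLASHATTENTION_DISABLE_DROPOUT and friends to the host and device compilers). None of them change the Python-facing API; they only turn now-unsupported inputs into immediate TORCH_CHECK failures.
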
diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h
index c5eef04..a863875 100644
--- a/csrc/flash_attn/src/flash_bwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -69,9 +69,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
+                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                         // If Is_local, set Is_causal to false
@@ -100,7 +100,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
 
 template<typename Kernel_traits, bool Is_dropout>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+#ifndef FLASHATTENTION_DISABLE_BACKWARD
     run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
+#endif
 }
 
 template<typename T>
@@ -114,7 +116,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
       C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) {  // 104 KB
             if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
                 run_flash_bwd, Is_dropout>(params, stream);
@@ -139,7 +141,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
       C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // Changing AtomLayoutMdQ from 2 to 4 takes the same time
         // run_flash_bwd>(params, stream);
         // run_flash_bwd>(params, stream);
@@ -184,7 +186,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
      C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
             if constexpr(!Is_dropout) {  // 92KB
                 run_flash_bwd, Is_dropout>(params, stream);
@@ -210,7 +212,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
      C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // run_flash_bwd>(params, stream);
         // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
         // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
@@ -243,7 +245,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
       C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
             run_flash_bwd, Is_dropout>(params, stream);
         } else {
@@ -263,7 +265,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
       C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 136 * 1024) {
             run_flash_bwd, Is_dropout>(params, stream);
         } else {
@@ -275,7 +277,7 @@
 template<typename T>
 void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 224;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         run_flash_bwd, Is_dropout>(params, stream);
     });
 }
@@ -291,7 +293,7 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 176 * 1024) {  // H100
             run_flash_bwd, Is_dropout>(params, stream);
         } else {  // A100, we don't do double buffering to save smem
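
Reviewer note: the #ifndef in run_flash_bwd pairs with the TORCH_CHECK(false, ...) added to mha_bwd and mha_varlen_bwd above. The API-level check fires first; the emptied template body is what actually stops the backward kernels from ever being instantiated. A minimal standalone sketch of the two-sided pattern, with every name hypothetical:

    #include <cstdio>
    #include <stdexcept>

    template <int kHeadDim>
    void launch_heavy_kernel() { std::printf("bwd kernel, hdim=%d\n", kHeadDim); }

    template <int kHeadDim>
    void run_backward() {
    #ifndef DISABLE_BACKWARD
        // Only in enabled builds does this line instantiate the heavy template.
        launch_heavy_kernel<kHeadDim>();
    #endif
    }

    void backward_api() {
    #ifdef DISABLE_BACKWARD
        // Mirrors TORCH_CHECK(false, ...): fail loudly at the API boundary,
        // so the empty run_backward body is never reached in a disabled build.
        throw std::runtime_error("this build does not support backward");
    #endif
        run_backward<128>();
    }
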
diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h
index 5852598..1d30d9e 100644
--- a/csrc/flash_attn/src/flash_fwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -42,10 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     const bool return_softmax = params.p_ptr != nullptr;
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-            BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                 BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                         // Will only return softmax if dropout, to reduce compilation time.
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If return_softmax, set IsEvenMNConst to false to reduce number of templates
@@ -83,11 +83,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                     BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                         BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                            BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                                 // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                                 // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                                 // If Is_local, set Is_causal to false
@@ -113,7 +113,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // If headdim is divisible by 64, then we set kBlockM = 8, etc.
     constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
     dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
-    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+    EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
         if (params.num_splits <= 2) {
             flash_fwd_splitkv_combine_kernel<<>>(params);
         } else if (params.num_splits <= 4) {
@@ -147,7 +147,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
 template<typename T>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             run_flash_fwd, Is_dropout, Is_causal>(params, stream);
         });
@@ -157,7 +157,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
@@ -181,7 +181,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
             if (is_sm8x) {
@@ -207,7 +207,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -244,7 +244,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For A100, H100, 128 x 32 is the fastest.
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -272,7 +272,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 run_flash_fwd, Is_dropout, Is_causal>(params, stream);
@@ -300,7 +300,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
      C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
                 run_flash_fwd, Is_dropout, Is_causal>(params, stream);
@@ -331,7 +331,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
      C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For A100, we want to run with 128 x 64 (128KB smem).
             // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
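
Reviewer note on the payoff: run_flash_fwd dispatches on five booleans (IsEvenMNConst, IsEvenKConst, Is_local, ReturnSoftmaxConst, Has_alibi), i.e. up to 2^5 = 32 kernel instantiations per head-dim/dtype/Is_dropout/Is_causal combination. With EVENK_SWITCH, LOCAL_SWITCH and ALIBI_SWITCH pinned by their disable flags, only 2^2 = 4 of those remain, and DROPOUT_SWITCH additionally halves every per-head-dim dispatch that branches on Is_dropout.
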
diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h
index ab55dd9..ca12fa1 100644
--- a/csrc/flash_attn/src/static_switch.h
+++ b/csrc/flash_attn/src/static_switch.h
@@ -14,6 +14,7 @@
 ///     some_function<BoolConst>(...);
 /// });
 /// ```
+
 #define BOOL_SWITCH(COND, CONST_NAME, ...)      \
   [&] {                                         \
     if (COND) {                                 \
@@ -25,6 +26,46 @@
     }                                           \
   }()
 
+#ifdef FLASHATTENTION_DISABLE_DROPOUT
+  #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
+  [&] {                                         \
+    constexpr static bool CONST_NAME = false;   \
+    return __VA_ARGS__();                       \
+  }()
+#else
+  #define DROPOUT_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_ALIBI
+  #define ALIBI_SWITCH(COND, CONST_NAME, ...) \
+  [&] {                                       \
+    constexpr static bool CONST_NAME = false; \
+    return __VA_ARGS__();                     \
+  }()
+#else
+  #define ALIBI_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
+  #define EVENK_SWITCH(COND, CONST_NAME, ...) \
+  [&] {                                       \
+    constexpr static bool CONST_NAME = true;  \
+    return __VA_ARGS__();                     \
+  }()
+#else
+  #define EVENK_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_LOCAL
+  #define LOCAL_SWITCH(COND, CONST_NAME, ...) \
+  [&] {                                       \
+    constexpr static bool CONST_NAME = false; \
+    return __VA_ARGS__();                     \
+  }()
+#else
+  #define LOCAL_SWITCH BOOL_SWITCH
+#endif
+
 #define FP16_SWITCH(COND, ...)     \
   [&] {                            \
     if (COND) {                    \
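
Reviewer note: the new macros lean entirely on how BOOL_SWITCH works, so here is a minimal standalone sketch of the mechanism (illustrative names only; MY_BOOL_SWITCH mirrors the real macro):

    #include <cstdio>

    #define MY_BOOL_SWITCH(COND, CONST_NAME, ...)   \
      [&] {                                         \
        if (COND) {                                 \
          constexpr static bool CONST_NAME = true;  \
          return __VA_ARGS__();                     \
        } else {                                    \
          constexpr static bool CONST_NAME = false; \
          return __VA_ARGS__();                     \
        }                                           \
      }()

    template <bool Has_alibi>
    void kernel_stub() { std::printf("Has_alibi = %d\n", Has_alibi); }

    int main(int argc, char **) {
        bool has_alibi = argc > 1;  // runtime value
        // Both kernel_stub<true> and kernel_stub<false> get instantiated here;
        // pinning the constant (as ALIBI_SWITCH does under
        // FLASHATTENTION_DISABLE_ALIBI) instantiates only one of them.
        MY_BOOL_SWITCH(has_alibi, Has_alibi, [&] { kernel_stub<Has_alibi>(); });
        return 0;
    }

The disabled variants drop the runtime branch entirely, which is why each one fixes CONST_NAME to the only value that build supports: true for EVENK_SWITCH, false for the other three.
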