diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h index 88788a9..964386b 100644 --- a/csrc/flash_attn/src/fmha.h +++ b/csrc/flash_attn/src/fmha.h @@ -36,7 +36,8 @@ #include #endif -#include +#include +#include #include diff --git a/csrc/flash_attn/src/fmha_bwd_hdim128.cu b/csrc/flash_attn/src/fmha_bwd_hdim128.cu index 0d17324..d171b3c 100644 --- a/csrc/flash_attn/src/fmha_bwd_hdim128.cu +++ b/csrc/flash_attn/src/fmha_bwd_hdim128.cu @@ -5,9 +5,8 @@ #include "fmha_bwd_launch_template.h" void run_fmha_bwd_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { - // work around for MSVC issue - FP16_SWITCH(params.is_bf16, [&] { + FP16_SWITCH(params.is_bf16, ({ using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>; run_fmha_bwd_loop(params, stream, configure); - }); + })); } \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_bwd_hdim32.cu b/csrc/flash_attn/src/fmha_bwd_hdim32.cu index eafec98..06c6e48 100644 --- a/csrc/flash_attn/src/fmha_bwd_hdim32.cu +++ b/csrc/flash_attn/src/fmha_bwd_hdim32.cu @@ -5,8 +5,7 @@ #include "fmha_bwd_launch_template.h" void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { - // work around for MSVC issue - FP16_SWITCH(params.is_bf16, [&] { + FP16_SWITCH(params.is_bf16, ({ if (params.seqlen_k == 128) { using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>; run_fmha_bwd_loop(params, stream, configure); @@ -14,5 +13,5 @@ void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const b using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>; run_fmha_bwd_loop(params, stream, configure); } - }); + })); } \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_bwd_hdim64.cu b/csrc/flash_attn/src/fmha_bwd_hdim64.cu index faa7595..7dd8650 100644 --- a/csrc/flash_attn/src/fmha_bwd_hdim64.cu +++ b/csrc/flash_attn/src/fmha_bwd_hdim64.cu @@ -5,8 +5,7 @@ #include "fmha_bwd_launch_template.h" void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { - // work around for MSVC issue - FP16_SWITCH(params.is_bf16, [&] { + FP16_SWITCH(params.is_bf16, ({ auto dprops = at::cuda::getCurrentDeviceProperties(); if (params.seqlen_k == 128) { using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; @@ -27,5 +26,5 @@ void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const b run_fmha_bwd_loop(params, stream, configure); } } - }); + })); } \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_bwd_launch_template.h b/csrc/flash_attn/src/fmha_bwd_launch_template.h index ffdde36..1e5be02 100644 --- a/csrc/flash_attn/src/fmha_bwd_launch_template.h +++ b/csrc/flash_attn/src/fmha_bwd_launch_template.h @@ -3,7 +3,6 @@ #pragma once #include "static_switch.h" -#include "fp16_switch.h" #include "fmha.h" #include "fmha_dgrad_kernel_1xN_loop.h" @@ -62,7 +61,7 @@ void run_fmha_bwd_loop(FMHA_dgrad_params ¶ms, cudaStream_t stream, const boo bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { + BOOL_SWITCH(is_dropout, IsDropoutConst, ({ auto kernel = params.is_causal ? &fmha_bwd_dq_dk_dv_loop_kernel : &fmha_bwd_dq_dk_dv_loop_kernel; @@ -111,5 +110,5 @@ void run_fmha_bwd_loop(FMHA_dgrad_params ¶ms, cudaStream_t stream, const boo kernel_seqparallel<<>>(params); } FMHA_CHECK_CUDA(cudaPeekAtLastError()); - }); + })); } diff --git a/csrc/flash_attn/src/fmha_fwd_hdim128.cu b/csrc/flash_attn/src/fmha_fwd_hdim128.cu index b434310..8d4477f 100644 --- a/csrc/flash_attn/src/fmha_fwd_hdim128.cu +++ b/csrc/flash_attn/src/fmha_fwd_hdim128.cu @@ -5,8 +5,8 @@ #include "fmha_fwd_launch_template.h" void run_fmha_fwd_hdim128(Launch_params &launch_params) { - FP16_SWITCH(launch_params.params.is_bf16, [&] { + FP16_SWITCH(launch_params.params.is_bf16, ({ using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; run_fmha_fwd_loop(launch_params); - }); + })); } \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_fwd_hdim32.cu b/csrc/flash_attn/src/fmha_fwd_hdim32.cu index b59a6a7..5fa48eb 100644 --- a/csrc/flash_attn/src/fmha_fwd_hdim32.cu +++ b/csrc/flash_attn/src/fmha_fwd_hdim32.cu @@ -5,7 +5,7 @@ #include "fmha_fwd_launch_template.h" void run_fmha_fwd_hdim32(Launch_params &launch_params) { - FP16_SWITCH(launch_params.params.is_bf16, [&] { + FP16_SWITCH(launch_params.params.is_bf16, ({ if (launch_params.params.seqlen_k == 128) { using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; run_fmha_fwd_loop(launch_params); @@ -13,5 +13,5 @@ void run_fmha_fwd_hdim32(Launch_params &launch_params) { using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; run_fmha_fwd_loop(launch_params); } - }); + })); } \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_fwd_hdim64.cu b/csrc/flash_attn/src/fmha_fwd_hdim64.cu index 15e8797..9776c6d 100644 --- a/csrc/flash_attn/src/fmha_fwd_hdim64.cu +++ b/csrc/flash_attn/src/fmha_fwd_hdim64.cu @@ -5,7 +5,7 @@ #include "fmha_fwd_launch_template.h" void run_fmha_fwd_hdim64(Launch_params &launch_params) { - FP16_SWITCH(launch_params.params.is_bf16, [&] { + FP16_SWITCH(launch_params.params.is_bf16, ({ if (launch_params.params.seqlen_k == 128) { using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; run_fmha_fwd_loop(launch_params); @@ -13,5 +13,5 @@ void run_fmha_fwd_hdim64(Launch_params &launch_params) { using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; run_fmha_fwd_loop(launch_params); } - }); + })); } diff --git a/csrc/flash_attn/src/fmha_fwd_launch_template.h b/csrc/flash_attn/src/fmha_fwd_launch_template.h index 2876d3a..1b01375 100644 --- a/csrc/flash_attn/src/fmha_fwd_launch_template.h +++ b/csrc/flash_attn/src/fmha_fwd_launch_template.h @@ -8,7 +8,6 @@ #include #include "static_switch.h" -#include "fp16_switch.h" #include "fmha.h" #include "fmha_fprop_kernel_1xN.h" @@ -57,7 +56,7 @@ void run_fmha_fwd_loop(Launch_params &launch_params) { // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. // https://github.com/kokkos/kokkos-kernels/issues/349 // https://github.com/HazyResearch/flash-attention/issues/21 - BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ({ auto kernel = launch_params.params.is_causal ? (launch_params.return_softmax ? &fmha_fwd_loop_kernel @@ -88,5 +87,5 @@ void run_fmha_fwd_loop(Launch_params &launch_params) { kernel<<>>( launch_params.params); FMHA_CHECK_CUDA(cudaPeekAtLastError()); - }); + })); } diff --git a/csrc/flash_attn/src/fp16_switch.h b/csrc/flash_attn/src/fp16_switch.h deleted file mode 100644 index fed7cb9..0000000 --- a/csrc/flash_attn/src/fp16_switch.h +++ /dev/null @@ -1,27 +0,0 @@ -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h - -// modified from static_switch.h -// because MSVC cannot handle std::conditional with constexpr variable - -#pragma once - -/// @param COND - a boolean expression to switch by -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// FP16_SWITCH(flag, [&] { -/// some_function(...); -/// }); -/// ``` -#define FP16_SWITCH(COND, ...) \ - [&] { \ - if (COND) { \ - using elem_type = __nv_bfloat16; \ - return __VA_ARGS__(); \ - } else { \ - using elem_type = __half; \ - return __VA_ARGS__(); \ - } \ - }() \ No newline at end of file diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index 7920ac0..a77bae6 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -9,17 +9,27 @@ /// /// Usage: /// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { +/// BOOL_SWITCH(flag, BoolConst, ({ /// some_function(...); -/// }); +/// })); /// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() +/// We need "({" and "})" to make sure that the code is a single argument being passed to the macro. +#define BOOL_SWITCH(COND, CONST_NAME, CODE) \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + CODE; \ + } else { \ + constexpr bool CONST_NAME = false; \ + CODE; \ + } + +// modified from BOOL_SWITCH +// because MSVC cannot handle std::conditional with constexpr variable +#define FP16_SWITCH(COND, CODE) \ + if (COND) { \ + using elem_type = __nv_bfloat16; \ + CODE; \ + } else { \ + using elem_type = __half; \ + CODE; \ + } \