diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
index 403ce92..2379a37 100644
--- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
+++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
@@ -2,6 +2,7 @@
  */
 
 #include "static_switch.h"
+#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"
 
@@ -52,8 +53,8 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
 }
 
 void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
-    BOOL_SWITCH(params.is_bf16, IsBf16Const, [&] {
-        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+    // work around for MSVC issue
+    FP16_SWITCH(params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (params.d == 16) {
             if( params.seqlen_k == 128 ) {
diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
index 8841d6b..32e793c 100644
--- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -29,6 +29,7 @@
 #include <cuda_fp16.h>
 
 #include "static_switch.h"
+#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
 
@@ -83,8 +84,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
 }
 
 void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure) {
-    BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
-        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+    FP16_SWITCH(launch_params.params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (launch_params.params.d == 16) {
             if( launch_params.params.seqlen_k == 128 ) {
diff --git a/csrc/flash_attn/src/fp16_switch.h b/csrc/flash_attn/src/fp16_switch.h
new file mode 100644
index 0000000..fed7cb9
--- /dev/null
+++ b/csrc/flash_attn/src/fp16_switch.h
@@ -0,0 +1,27 @@
+// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
+// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
+
+// modified from static_switch.h
+// because MSVC cannot handle std::conditional with constexpr variable
+
+#pragma once
+
+/// @param COND       - a boolean expression to switch by
+/// @param ...        - code to execute for true and false
+///
+/// Usage:
+/// ```
+/// FP16_SWITCH(flag, [&] {
+///     some_function(...);
+/// });
+/// ```
+#define FP16_SWITCH(COND, ...)                                 \
+    [&] {                                                      \
+        if (COND) {                                            \
+            using elem_type = __nv_bfloat16;                   \
+            return __VA_ARGS__();                              \
+        } else {                                               \
+            using elem_type = __half;                          \
+            return __VA_ARGS__();                              \
+        }                                                      \
+    }()
\ No newline at end of file
diff --git a/setup.py b/setup.py
index eabcf06..ec1415a 100644
--- a/setup.py
+++ b/setup.py
@@ -125,10 +125,11 @@ ext_modules.append(
             "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
         ],
         extra_compile_args={
-            "cxx": ["-O3"] + generator_flag,
+            "cxx": ["-O3", "-std=c++17"] + generator_flag,
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",
+                    "-std=c++17",
                     "-U__CUDA_NO_HALF_OPERATORS__",
                     "-U__CUDA_NO_HALF_CONVERSIONS__",
                     "--expt-relaxed-constexpr",