diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 96f0cff..001acac 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -46,7 +46,7 @@ void set_params_fprop(Flash_fwd_params &params, bool seqlenq_ngroups_swapped=false) { // Reset the parameters - memset(&params, 0, sizeof(params)); + params = {}; params.is_bf16 = q.dtype() == torch::kBFloat16; diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 834fc4c..fd81c88 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -11,6 +11,40 @@ #include "flash_bwd_preprocess_kernel.h" #include "flash_bwd_kernel.h" +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) 
\ +template \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params) + +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) { + #if defined(ARCH_SUPPORTS_FLASH) + flash::compute_dq_dk_dv(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) { + #if defined(ARCH_SUPPORTS_FLASH) + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + flash::compute_dq_dk_dv_seqk_parallel(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + + template __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) { flash::compute_dot_do_o(params); @@ -21,17 +55,6 @@ __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) { flash::clear_dKVaccum(params); } -template -__global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params) { - flash::compute_dq_dk_dv(params); -} - -template -__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(__grid_constant__ const Flash_bwd_params params) { - static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - flash::compute_dq_dk_dv_seqk_parallel(params); -} - template __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) { flash::convert_dQ(params, nsplits); diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 1d30d9e..fa6a6f6 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -10,19 +10,40 @@ #include "flash.h" #include "flash_fwd_kernel.h" -template -__global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) { - 
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - flash::compute_attn(params); +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ +template \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) { + #if defined(ARCH_SUPPORTS_FLASH) + static_assert(!(Is_causal && Is_local)); // Enforce constraints + flash::compute_attn(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif } -template -__global__ void flash_fwd_splitkv_kernel(__grid_constant__ const Flash_fwd_params params) { - flash::compute_attn_splitkv(params); +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) { + #if defined(ARCH_SUPPORTS_FLASH) + flash::compute_attn_splitkv(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif } -template -__global__ void flash_fwd_splitkv_combine_kernel(__grid_constant__ const Flash_fwd_params params) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { static_assert(Log_max_splits >= 1); flash::combine_attn_seqk_parallel(params); }