Add in, macrosf for defining __grid_constant__ (#852)

This commit is contained in:
Driss Guessous 2024-03-15 00:48:54 -07:00 committed by GitHub
parent 2a15840f09
commit 4a73e903da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 65 additions and 21 deletions

View File

@ -46,7 +46,7 @@ void set_params_fprop(Flash_fwd_params &params,
bool seqlenq_ngroups_swapped=false) { bool seqlenq_ngroups_swapped=false) {
// Reset the parameters // Reset the parameters
memset(&params, 0, sizeof(params)); params = {};
params.is_bf16 = q.dtype() == torch::kBFloat16; params.is_bf16 = q.dtype() == torch::kBFloat16;

View File

@ -11,6 +11,40 @@
#include "flash_bwd_preprocess_kernel.h" #include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h" #include "flash_bwd_kernel.h"
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif
// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
template<bool Clear_dQaccum=true, typename Kernel_traits> template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) { __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params); flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
@ -21,17 +55,6 @@ __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
flash::clear_dKVaccum<Kernel_traits>(params); flash::clear_dKVaccum<Kernel_traits>(params);
} }
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params) {
flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(__grid_constant__ const Flash_bwd_params params) {
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
}
template<typename Kernel_traits> template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) { __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
flash::convert_dQ<Kernel_traits>(params, nsplits); flash::convert_dQ<Kernel_traits>(params, nsplits);

View File

@ -10,19 +10,40 @@
#include "flash.h" #include "flash.h"
#include "flash_fwd_kernel.h" #include "flash_fwd_kernel.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax> // Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
__global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false #define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif
// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params); flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
} }
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV> DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
__global__ void flash_fwd_splitkv_kernel(__grid_constant__ const Flash_fwd_params params) { #if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params); flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
} }
template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K> DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
__global__ void flash_fwd_splitkv_combine_kernel(__grid_constant__ const Flash_fwd_params params) {
static_assert(Log_max_splits >= 1); static_assert(Log_max_splits >= 1);
flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params); flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
} }