Add in, macrosf for defining __grid_constant__ (#852)
This commit is contained in:
parent
2a15840f09
commit
4a73e903da
@ -46,7 +46,7 @@ void set_params_fprop(Flash_fwd_params ¶ms,
|
||||
bool seqlenq_ngroups_swapped=false) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
params = {};
|
||||
|
||||
params.is_bf16 = q.dtype() == torch::kBFloat16;
|
||||
|
||||
|
||||
@ -11,6 +11,40 @@
|
||||
#include "flash_bwd_preprocess_kernel.h"
|
||||
#include "flash_bwd_kernel.h"
|
||||
|
||||
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#define ARCH_SUPPORTS_FLASH
|
||||
#define KERNEL_PARAM_MODIFIER __grid_constant__
|
||||
#else
|
||||
#define KERNEL_PARAM_MODIFIER
|
||||
#endif
|
||||
|
||||
// Define a macro for unsupported architecture handling to centralize the error message
|
||||
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
|
||||
|
||||
// Use a macro to clean up kernel definitions
|
||||
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
|
||||
template<typename Kernel_traits, __VA_ARGS__> \
|
||||
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
|
||||
|
||||
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
||||
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
template<bool Clear_dQaccum=true, typename Kernel_traits>
|
||||
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
|
||||
flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
|
||||
@ -21,17 +55,6 @@ __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
|
||||
flash::clear_dKVaccum<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
|
||||
__global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params) {
|
||||
flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
|
||||
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(__grid_constant__ const Flash_bwd_params params) {
|
||||
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
||||
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
|
||||
flash::convert_dQ<Kernel_traits>(params, nsplits);
|
||||
|
||||
@ -10,19 +10,40 @@
|
||||
#include "flash.h"
|
||||
#include "flash_fwd_kernel.h"
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
|
||||
__global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) {
|
||||
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
||||
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
|
||||
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#define ARCH_SUPPORTS_FLASH
|
||||
#define KERNEL_PARAM_MODIFIER __grid_constant__
|
||||
#else
|
||||
#define KERNEL_PARAM_MODIFIER
|
||||
#endif
|
||||
|
||||
// Define a macro for unsupported architecture handling to centralize the error message
|
||||
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
|
||||
|
||||
// Use a macro to clean up kernel definitions
|
||||
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
|
||||
template<typename Kernel_traits, __VA_ARGS__> \
|
||||
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
|
||||
|
||||
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
static_assert(!(Is_causal && Is_local)); // Enforce constraints
|
||||
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
|
||||
__global__ void flash_fwd_splitkv_kernel(__grid_constant__ const Flash_fwd_params params) {
|
||||
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
|
||||
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
|
||||
__global__ void flash_fwd_splitkv_combine_kernel(__grid_constant__ const Flash_fwd_params params) {
|
||||
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
|
||||
static_assert(Log_max_splits >= 1);
|
||||
flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user