Change constexpr int to constexpr static int
parent 3a9fe7b0fa
commit 5a83425442
@@ -198,7 +198,7 @@ includes QKV projection, output projection), see the MHA [implementation](https:
 
 ## Changelog
 
-### 2.0
+### 2.0: Complete rewrite, 2x faster
 Upgrading from FlashAttention (1.x) to FlashAttention-2
 
 These functions have been renamed:
@@ -214,7 +214,7 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
 ```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
-### 2.1
+### 2.1: Change behavior of causal flag
 
 If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
 bottom right corner of the attention matrix, instead of the top-left corner.
@@ -243,7 +243,7 @@ v2.1:
     1 1
 If the row of the mask is all zero, the output will be zero.
 
-### 2.2
+### 2.2: Optimize for inference
 
 Optimize for inference (iterative decoding) when query has very small sequence
 length (e.g., query sequence length = 1). The bottleneck here is to load KV
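(Not part of the diff.) As a quick illustration of the v2.1 causal-mask convention described in the two hunks above, here is a minimal sketch that builds the bottom-right-aligned mask; the helper name `build_causal_mask` is hypothetical and only mirrors the rule stated in the changelog text.

```python
import torch

def build_causal_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # v2.1 convention: the mask is aligned to the bottom-right corner, so
    # query position i may attend to key positions j with j <= i + (seqlen_k - seqlen_q).
    i = torch.arange(seqlen_q).unsqueeze(1)  # (seqlen_q, 1)
    j = torch.arange(seqlen_k).unsqueeze(0)  # (1, seqlen_k)
    return j <= i + (seqlen_k - seqlen_q)    # True = keep, False = masked out

# seqlen_q=5, seqlen_k=2 reproduces the v2.1 example shown above: the first
# three rows are all zero (their output is zero), then 1 0 and 1 1.
print(build_causal_mask(5, 2).int())
```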
@@ -256,7 +256,7 @@ See the function `flash_attn_with_kvcache` with more features for inference
 Thanks to the xformers team, and in particular Daniel Haziza, for this
 collaboration.
 
-### 2.3
+### 2.3: Local (i.e., sliding window) attention
 
 Implement sliding window attention (i.e., local attention). Thanks to [Mistral
 AI](https://mistral.ai/) and in particular Timothée Lacroix for this
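(Not part of the diff.) For context on the 2.2 and 2.3 entries above, a hedged usage sketch of the inference path and the sliding-window option from the Python side; the tensor shapes and keyword names (`window_size`, `cache_seqlens`, `k=`/`v=`) are assumptions based on the public `flash_attn` interface of that era and may differ slightly between releases.

```python
import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache

batch, nheads, headdim = 2, 16, 64
device, dtype = "cuda", torch.float16

# 2.3: sliding-window (local) attention, e.g. a window of 128 tokens to the left.
q = torch.randn(batch, 1024, nheads, headdim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn_func(q, k, v, causal=True, window_size=(128, 0))

# 2.2: iterative decoding with a KV cache, query sequence length = 1.
k_cache = torch.zeros(batch, 2048, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device=device)
q_step = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
k_step = torch.randn_like(q_step)
v_step = torch.randn_like(q_step)
out_step = flash_attn_with_kvcache(
    q_step, k_cache, v_cache, k=k_step, v=v_step,
    cache_seqlens=cache_seqlens, causal=True,
)
```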
@@ -137,7 +137,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool con
 
 template<typename T>
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -158,7 +158,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
 
 template<typename T>
 void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -201,7 +201,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
 
 template<typename T>
 void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -228,7 +228,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
 
 template<typename T>
 void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -264,7 +264,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
 
 template<typename T>
 void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -281,7 +281,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
 
 template<typename T>
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -298,7 +298,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
 
 template<typename T>
 void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
     });
@@ -306,7 +306,7 @@ void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bo
 
 template<typename T>
 void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -104,7 +104,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         // We want kBlockM to be as small as possible for more parallelism.
         // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
         // If headdim is divisible by 64, then we set kBlockM = 8, etc.
-        constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+        constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
         dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             if (params.num_splits <= 2) {
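(Not part of the diff.) The comment in the hunk above compresses a small piece of arithmetic: 128 threads loading 512 elements per pass means 4 elements per thread, so a block of 4 rows covers exactly 512 elements when kHeadDim is divisible by 128, a block of 8 rows does when it is divisible by 64, and otherwise kBlockM falls back to 16. A throwaway Python sketch of the same selection rule, for illustration only:

```python
def combine_kblock_m(head_dim: int) -> int:
    # Mirrors: kHeadDim % 128 == 0 ? 4 : (kHeadDim % 64 == 0 ? 8 : 16)
    if head_dim % 128 == 0:
        return 4   # 4 * 128 = 512 elements per 128-thread load
    if head_dim % 64 == 0:
        return 8   # 8 * 64 = 512 elements per 128-thread load
    return 16      # fallback for the remaining head dimensions

for hd in (32, 64, 96, 128, 160, 192, 224, 256):
    print(f"headdim={hd:3d} -> kBlockM={combine_kblock_m(hd)}")
```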
@@ -129,17 +129,17 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T, int Headdim>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int kBlockM = 64; // Fixed for all head dimensions
+    constexpr static int kBlockM = 64; // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
     // Also for headdim 160 with block size 64 x 128 after the rotary addition.
-    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
 }
 
 template<typename T>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -149,7 +149,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
@@ -171,7 +171,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -197,7 +197,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -234,7 +234,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -264,7 +264,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
@@ -283,7 +283,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -309,7 +309,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_sm, max_smem_per_block;