Change constexpr int to constexpr static int

Tri Dao 2023-10-08 16:26:33 -07:00
parent 3a9fe7b0fa
commit 5a83425442
3 changed files with 23 additions and 23 deletions
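
The change itself is mechanical: every `constexpr int` local in the kernel launch functions below becomes `constexpr static int`. The commit message does not say why, but one plausible motivation is the `BOOL_SWITCH` dispatch visible in the hunks: these constants are named inside the `[&]` lambdas that the macro instantiates, and a local with static storage duration is never captured, so it stays usable as a constant expression in the lambda body no matter how a given compiler treats captured `constexpr` locals. Below is a minimal sketch of the pattern, with a simplified stand-in for the repository's `BOOL_SWITCH` macro and a dummy `run_kernel` in place of the real launches (both are assumptions for illustration):

```cpp
// Minimal, self-contained sketch of the launcher pattern (not the actual
// flash-attention sources): a runtime bool is turned into a compile-time
// constant by instantiating a lambda per case.
#include <cstdio>

// Simplified stand-in for the repository's BOOL_SWITCH macro.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

// Dummy stand-in for the templated kernel launches (run_flash_fwd/bwd).
template <int kHeadDim, bool Is_dropout>
void run_kernel() {
    std::printf("Headdim=%d dropout=%d\n", kHeadDim, (int)Is_dropout);
}

template <typename T>
void run_mha_fwd_hdim64_sketch(float p_dropout) {
    // `static` gives Headdim static storage duration: the lambda below can
    // name it without capturing it, and it stays usable as a constant
    // expression inside the lambda body regardless of how a particular
    // compiler treats captured constexpr locals.
    constexpr static int Headdim = 64;
    BOOL_SWITCH(p_dropout < 1.f, Is_dropout, [&] {
        run_kernel<Headdim, Is_dropout>();
    });
}

int main() {
    run_mha_fwd_hdim64_sketch<float>(0.1f);  // dropout path
    run_mha_fwd_hdim64_sketch<float>(1.0f);  // no-dropout path
}
```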


@@ -198,7 +198,7 @@ includes QKV projection, output projection), see the MHA [implementation](https:
## Changelog
### 2.0
### 2.0: Complete rewrite, 2x faster
Upgrading from FlashAttention (1.x) to FlashAttention-2
These functions have been renamed:
@@ -214,7 +214,7 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
### 2.1
### 2.1: Change behavior of causal flag
If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.
@@ -243,7 +243,7 @@ v2.1:
1 1
If the row of the mask is all zero, the output will be zero.
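
For concreteness, here is a tiny standalone sketch (illustration only, not repository code) of the 2.1 rule: with the mask anchored at the bottom right, query row i may attend to key j exactly when j <= i + (seqlen_k - seqlen_q). Printing it for seqlen_q = 2, seqlen_k = 5 and for seqlen_q = 5, seqlen_k = 2 reproduces rows like those quoted above, including the all-zero rows whose output is zero:

```cpp
// Standalone illustration (not repository code) of the v2.1 alignment:
// with the causal mask anchored at the bottom-right of the
// seqlen_q x seqlen_k attention matrix, query row i keeps key j iff
//   j <= i + (seqlen_k - seqlen_q).
#include <cstdio>

static void print_causal_mask(int seqlen_q, int seqlen_k) {
    std::printf("seqlen_q=%d, seqlen_k=%d\n", seqlen_q, seqlen_k);
    for (int i = 0; i < seqlen_q; ++i) {
        for (int j = 0; j < seqlen_k; ++j) {
            std::printf("%d ", j <= i + (seqlen_k - seqlen_q) ? 1 : 0);
        }
        std::printf("\n");
    }
}

int main() {
    print_causal_mask(2, 5);  // rows: 1 1 1 1 0 / 1 1 1 1 1
    print_causal_mask(5, 2);  // first rows all zero -> their output is zero
}
```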
### 2.2
### 2.2: Optimize for inference
Optimize for inference (iterative decoding) when query has very small sequence
length (e.g., query sequence length = 1). The bottleneck here is to load KV
@@ -256,7 +256,7 @@ See the function `flash_attn_with_kvcache` with more features for inference
Thanks to the xformers team, and in particular Daniel Haziza, for this
collaboration.
### 2.3
### 2.3: Local (i.e., sliding window) attention
Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
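
The window semantics sketched below are an assumption for illustration rather than the repository's kernel code: with a window (window_left, window_right) and the same bottom-right alignment as above, query row i keeps key j only when j lies in [i + offset - window_left, i + offset + window_right], where offset = seqlen_k - seqlen_q; causal attention is the special case of window_right = 0 with an unbounded left side.

```cpp
// Sketch of a sliding-window (local attention) predicate, using the same
// bottom-right alignment; the exact semantics here are an assumption for
// illustration, not taken from the repository's kernels.
#include <cstdio>

static bool in_window(int i, int j, int seqlen_q, int seqlen_k,
                      int window_left, int window_right) {
    const int offset = seqlen_k - seqlen_q;
    return j >= i + offset - window_left && j <= i + offset + window_right;
}

int main() {
    const int seqlen = 4;
    // Keep the diagonal plus the two keys to its left (window_left = 2,
    // window_right = 0), i.e. causal + local.
    for (int i = 0; i < seqlen; ++i) {
        for (int j = 0; j < seqlen; ++j) {
            std::printf("%d ", in_window(i, j, seqlen, seqlen, 2, 0) ? 1 : 0);
        }
        std::printf("\n");
    }
}
```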


@@ -137,7 +137,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool con
template<typename T>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
constexpr int Headdim = 32;
constexpr static int Headdim = 32;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
@@ -158,7 +158,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
constexpr int Headdim = 64;
constexpr static int Headdim = 64;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
@@ -201,7 +201,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
template<typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
constexpr int Headdim = 96;
constexpr static int Headdim = 96;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
@@ -228,7 +228,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
constexpr int Headdim = 128;
constexpr static int Headdim = 128;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
@@ -264,7 +264,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
template<typename T>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
constexpr int Headdim = 160;
constexpr static int Headdim = 160;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
@@ -281,7 +281,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
template<typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
constexpr int Headdim = 192;
constexpr static int Headdim = 192;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
@@ -298,7 +298,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
template<typename T>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
constexpr int Headdim = 224;
constexpr static int Headdim = 224;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
});
@@ -306,7 +306,7 @@ void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bo
template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
constexpr int Headdim = 256;
constexpr static int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
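
The hunks above are cut off right after `int max_smem_per_block;`, but each backward launcher follows the same shape: fix `Headdim`, ask how much shared memory a block can use on the current device, then dispatch to kernel traits whose tiles fit. Here is a hedged sketch of that query with the plain CUDA runtime; the attribute read here and the `kBlockN` choice are assumptions for illustration, not the repository's actual logic:

```cpp
// Sketch of the device query these launchers rely on (CUDA runtime API).
// Assumptions: the truncated line most likely reads the opt-in
// shared-memory-per-block limit; the kBlockN choice below is illustrative
// only, not the repository's actual heuristic.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0;
    cudaGetDevice(&device);

    int max_smem_per_block = 0;
    // Largest dynamic shared memory a block can opt into on this device
    // (roughly 163 KB on A100/SM80, about 99 KB on most other SM8x parts).
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    constexpr static int Headdim = 128;
    // Hypothetical tile choice: take the wider K/V tile only if it fits.
    const int kBlockN = max_smem_per_block >= 144 * 1024 ? 128 : 64;
    std::printf("Headdim=%d, max_smem_per_block=%d -> kBlockN=%d\n",
                Headdim, max_smem_per_block, kBlockN);
    return 0;
}
```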


@@ -104,7 +104,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// We want kBlockM to be as small as possible for more parallelism.
// With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) {
@@ -129,17 +129,17 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int kBlockM = 64; // Fixed for all head dimensions
constexpr static int kBlockM = 64; // Fixed for all head dimensions
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
// and for headdim 192 with block size 64 x 128.
// Also for headdim 160 with block size 64 x 128 after the rotary addition.
constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
}
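
The comment in the hunk above determines kBlockM for the combine step from the head dimension (128 threads load 512 elements per pass, so kBlockM is 4 when headdim is divisible by 128, 8 when divisible by 64, and 16 otherwise), and `grid_combine` is a ceiling division over the b * h * seqlen_q rows to be combined. A standalone sketch that reproduces just this arithmetic, with made-up batch, head, and sequence-length values:

```cpp
// Standalone reproduction of the combine-kernel sizing shown above:
// kBlockM = 4 if headdim % 128 == 0, 8 if headdim % 64 == 0, else 16;
// grid_combine is a ceiling division over b * h * seqlen_q rows.
// The b / h / seqlen_q values here are made up for the printout.
#include <cstdio>
#include <initializer_list>

constexpr int combine_kBlockM(int headdim) {
    return headdim % 128 == 0 ? 4 : (headdim % 64 == 0 ? 8 : 16);
}

int main() {
    const int b = 2, h = 16, seqlen_q = 1;  // iterative decoding: 1 query token
    for (int headdim : {32, 64, 96, 128, 256}) {
        const int kBlockM = combine_kBlockM(headdim);
        const int grid_combine = (b * h * seqlen_q + kBlockM - 1) / kBlockM;
        std::printf("headdim=%3d -> kBlockM=%2d, grid_combine=%d\n",
                    headdim, kBlockM, grid_combine);
    }
}
```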
template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 32;
constexpr static int Headdim = 32;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -149,7 +149,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 64;
constexpr static int Headdim = 64;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
@@ -171,7 +171,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 96;
constexpr static int Headdim = 96;
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
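
`is_sm8x` singles out compute capability 8.6/8.7/8.9 devices, which offer less shared memory per block than SM80 (A100); that is presumably why these forward launchers pick different tile shapes for them. A self-contained version of the check, using the plain CUDA runtime in place of PyTorch's `at::cuda::getCurrentDeviceProperties`:

```cpp
// Self-contained version of the is_sm8x check, using the plain CUDA runtime
// instead of PyTorch's at::cuda::getCurrentDeviceProperties.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp props{};
    cudaGetDeviceProperties(&props, device);
    // Compute capability 8.6 / 8.7 / 8.9, i.e. SM8x other than A100 (SM80).
    const bool is_sm8x = props.major == 8 && props.minor > 0;
    std::printf("sm%d%d, is_sm8x=%d\n", props.major, props.minor, (int)is_sm8x);
    return 0;
}
```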
@@ -197,7 +197,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 128;
constexpr static int Headdim = 128;
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -234,7 +234,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 160;
constexpr static int Headdim = 160;
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -264,7 +264,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 192;
constexpr static int Headdim = 192;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
@@ -283,7 +283,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 224;
constexpr static int Headdim = 224;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
@@ -309,7 +309,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 256;
constexpr static int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_sm, max_smem_per_block;