Commit Graph

46 Commits

Author SHA1 Message Date
Tri Dao
2e33fc8e36 Add GPT and ViT models 2022-11-13 22:30:23 -08:00
Tri Dao
d4b320b31f Add MLP, MHA, Block, Embedding modules 2022-11-13 22:06:44 -08:00
Tri Dao
fa6d1ce44f Add fused_dense and dropout_add_layernorm CUDA extensions 2022-11-13 21:59:20 -08:00
Tri Dao
343492ec30 Make nccl operations async in CrossEntropyLossParallel 2022-11-13 17:27:26 -08:00
Tri Dao
7c9953815a Add fused cross entropy loss 2022-11-12 21:58:41 -08:00
Tri Dao
55797f32c9 Remove RotaryEmbedding from FlashAttention module
To avoid import error if one doesn't have rotary_emb installed
2022-11-10 11:54:36 -08:00
Tri Dao
908a5b2244 Set num_warps=4 for headdim=64 in Triton fw (h/t Michael Benesty) 2022-11-07 08:58:16 -08:00
Tri Dao
7479757191 Fix pipelining bug in Triton bwd with bias_type=matrix 2022-11-06 11:50:35 -08:00
Tri Dao
557781933d Parallelize CUDA bwd along seqlen_k instead of seqlen_q
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
Tri Dao
ca81f32e04 Implement rotary embedding in CUDA 2022-11-04 22:42:01 -07:00
Tri Dao
62025e1aff Fix more race condition in Triton bwd when there's bias 2022-11-04 12:53:09 -07:00
Tri Dao
ff78ea4123 Fix race condition in Triton bwd when there's bias 2022-11-04 11:20:27 -07:00
Tri Dao
86862cfd7b Implement attention bias for Triton version 2022-11-04 10:33:54 -07:00
Tri Dao
470010f59b Fix race condition for Triton bwd for headdim 48 and 96 2022-11-03 15:52:40 -07:00
Tri Dao
aacc10fbab Fix race condition in Triton bwd for non-po2 headdims 2022-11-02 07:32:54 -07:00
Tri Dao
1fb12afdfb Avoid memcpy in the Triton bwd 2022-11-01 15:06:45 -07:00
Tri Dao
731f154de3 Fix race conditions in the Triton bwd for headdim=64 2022-11-01 15:05:55 -07:00
Tri Dao
9b0bc97872 Fix race condition in Triton fwd 2022-10-31 14:34:57 -07:00
Tri Dao
215930bce3 Fix EVEN_M & EVEN_HEADDIM for headdim=40 in Triton bwd 2022-10-31 01:41:49 -07:00
Tri Dao
4f81aff46e Add debug_barrier for all headdims in Triton bwd 2022-10-31 01:25:02 -07:00
Tri Dao
bedcbd6a71 Disable some autotune configs that give wrong results in Triton bwd 2022-10-31 01:05:51 -07:00
Tri Dao
e78d509c64 [WIP] Support all head dimensions up to 128 in the Triton bwd
WIP because there seems to be some race conditions for head dimensions other
than 16, 32, 64, 128.
2022-10-31 00:46:22 -07:00
Tri Dao
008951f1d9 Support all head dimensions up to 128 in the Triton fwd 2022-10-30 22:10:48 -07:00
Tri Dao
b910bf14c1 Support arbitrary seqlens (both q & k) in Triton bwd 2022-10-30 21:50:53 -07:00
Tri Dao
dc55469355 Support arbitrary seqlen_k in Triton bwd 2022-10-30 21:26:26 -07:00
Tri Dao
d11341fd1a Fix Triton fwd to support seqlen not multiples of 128 2022-10-30 19:05:47 -07:00
Tri Dao
b0c0db81f6 Implement FlashAttention in Triton 2022-10-30 18:09:11 -07:00
Tri Dao
46fd2a20b2 Support all head dims that are multiples of 8, up to 128 2022-10-24 16:04:21 -07:00
Tri Dao
ed553e9238 Add Megatron attention implementation for benchmarking 2022-10-23 23:04:16 -07:00
Tri Dao
50ca23488d Add Triton implementation for benchmarking 2022-10-23 17:25:56 -07:00
Tri Dao
fb88e5e4b3 Move benchmark utils, support AMP 2022-10-23 12:50:00 -07:00
Tri Dao
a5a8806d1a Split bwd on the seqlen_q dimension 2022-10-23 11:35:15 -07:00
Tri Dao
a44f48df5a Split fwd on the seqlen_q dimension 2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6 Rework dropout to decouple forward and backward
They don't have to have the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
Tri Dao
1b9facacc3 Fix QKV interface to allocate output in Python 2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848 Implement attention kernel that splits the batch into two 2022-10-13 20:49:02 -07:00
Antoine Adam
4e38df059e
remove numpy dependency
According to the `setup.py` file, only dependencies are torch and einops. But the `bert_padding.py` file requires `numpy` only to multiply the elements of a `torch.Size` object. This change aims at allowing the use of FlashAttention without numpy.
2022-10-06 19:17:15 +02:00
Tri Dao
13403e8115 Relax assert to allow both bf16 and fp16 2022-09-11 12:09:43 -07:00
eric-tc-wong
b410d14f28
Update flash_attention.py
Recasting query and key after rotary_emb()
2022-09-06 17:29:49 -04:00
Tri Dao
19d1261025 Add back need_weights in FlashMHA 2022-08-09 10:14:10 -07:00
Tri Dao
6cc7342575 Support index_first_axis with more than 2 dimensions 2022-08-05 09:48:16 -07:00
Tri Dao
713ea302d7 Allow headdim 128 in FlashMHA interface 2022-08-05 09:47:22 -07:00
Tri Dao
a5559a0e75 Do P * dP (pointwise) in the bwd in fp32 instead of fp16 2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af Implement cross attention 2022-07-03 17:48:12 -07:00
Gustaf
af4a9ce024 Add missing __init__.py 2022-07-03 02:04:55 -04:00
Tri Dao
5a61cb7729 Rename src -> flash_attn 2022-06-01 18:50:26 -07:00