Tri Dao
1feb94265c
[ViT] Use dropout_add_ln for the 1st layer norm
2022-11-23 12:48:56 -08:00
Tri Dao
b8ccd20098
[Triton] Fix variable name from qkv to kv (h/t FrankZijlstra)
2022-11-22 02:07:32 -08:00
Tri Dao
054816177e
Bump version to 0.2.1
2022-11-20 22:35:59 -08:00
Tri Dao
0fa5c0d7ef
Add PatchEmbed
2022-11-17 16:56:06 -08:00
Tri Dao
ece539abd6
Add __init__.py files to subdirectories for installation
2022-11-17 16:55:44 -08:00
Tri Dao
71f674ae23
[Rotary] Customize base, support seqlen_offset
2022-11-17 11:43:36 -08:00
Tri Dao
2e33fc8e36
Add GPT and ViT models
2022-11-13 22:30:23 -08:00
Tri Dao
d4b320b31f
Add MLP, MHA, Block, Embedding modules
2022-11-13 22:06:44 -08:00
Tri Dao
fa6d1ce44f
Add fused_dense and dropout_add_layernorm CUDA extensions
2022-11-13 21:59:20 -08:00
Tri Dao
343492ec30
Make NCCL operations async in CrossEntropyLossParallel
2022-11-13 17:27:26 -08:00
Tri Dao
7c9953815a
Add fused cross entropy loss
2022-11-12 21:58:41 -08:00
Tri Dao
55797f32c9
Remove RotaryEmbedding from FlashAttention module
To avoid an import error if rotary_emb is not installed
2022-11-10 11:54:36 -08:00
Tri Dao
908a5b2244
Set num_warps=4 for headdim=64 in Triton fwd (h/t Michael Benesty)
2022-11-07 08:58:16 -08:00
Tri Dao
7479757191
Fix pipelining bug in Triton bwd with bias_type=matrix
2022-11-06 11:50:35 -08:00
Tri Dao
557781933d
Parallelize CUDA bwd along seqlen_k instead of seqlen_q
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
Tri Dao
ca81f32e04
Implement rotary embedding in CUDA
2022-11-04 22:42:01 -07:00
Tri Dao
62025e1aff
Fix more race conditions in Triton bwd when there's bias
2022-11-04 12:53:09 -07:00
Tri Dao
ff78ea4123
Fix race condition in Triton bwd when there's bias
2022-11-04 11:20:27 -07:00
Tri Dao
86862cfd7b
Implement attention bias for Triton version
2022-11-04 10:33:54 -07:00
Tri Dao
470010f59b
Fix race condition for Triton bwd for headdim 48 and 96
2022-11-03 15:52:40 -07:00
Tri Dao
aacc10fbab
Fix race condition in Triton bwd for non-power-of-2 headdims
2022-11-02 07:32:54 -07:00
Tri Dao
1fb12afdfb
Avoid memcpy in the Triton bwd
2022-11-01 15:06:45 -07:00
Tri Dao
731f154de3
Fix race conditions in the Triton bwd for headdim=64
2022-11-01 15:05:55 -07:00
Tri Dao
9b0bc97872
Fix race condition in Triton fwd
2022-10-31 14:34:57 -07:00
Tri Dao
215930bce3
Fix EVEN_M & EVEN_HEADDIM for headdim=40 in Triton bwd
2022-10-31 01:41:49 -07:00
Tri Dao
4f81aff46e
Add debug_barrier for all headdims in Triton bwd
2022-10-31 01:25:02 -07:00
Tri Dao
bedcbd6a71
Disable some autotune configs that give wrong results in Triton bwd
2022-10-31 01:05:51 -07:00
Tri Dao
e78d509c64
[WIP] Support all head dimensions up to 128 in the Triton bwd
WIP because there seem to be race conditions for head dimensions other than 16, 32, 64, and 128.
2022-10-31 00:46:22 -07:00
Tri Dao
008951f1d9
Support all head dimensions up to 128 in the Triton fwd
2022-10-30 22:10:48 -07:00
Tri Dao
b910bf14c1
Support arbitrary seqlens (both q & k) in Triton bwd
2022-10-30 21:50:53 -07:00
Tri Dao
dc55469355
Support arbitrary seqlen_k in Triton bwd
2022-10-30 21:26:26 -07:00
Tri Dao
d11341fd1a
Fix Triton fwd to support seqlens that are not multiples of 128
2022-10-30 19:05:47 -07:00
Tri Dao
b0c0db81f6
Implement FlashAttention in Triton
2022-10-30 18:09:11 -07:00
Tri Dao
46fd2a20b2
Support all head dims that are multiples of 8, up to 128
2022-10-24 16:04:21 -07:00
Tri Dao
ed553e9238
Add Megatron attention implementation for benchmarking
2022-10-23 23:04:16 -07:00
Tri Dao
50ca23488d
Add Triton implementation for benchmarking
2022-10-23 17:25:56 -07:00
Tri Dao
fb88e5e4b3
Move benchmark utils, support AMP
2022-10-23 12:50:00 -07:00
Tri Dao
a5a8806d1a
Split bwd on the seqlen_q dimension
2022-10-23 11:35:15 -07:00
Tri Dao
a44f48df5a
Split fwd on the seqlen_q dimension
2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6
Rework dropout to decouple forward and backward
The forward and backward passes no longer need to use the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
Tri Dao
1b9facacc3
Fix QKV interface to allocate output in Python
2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848
Implement attention kernel that splits the batch into two
2022-10-13 20:49:02 -07:00
Antoine Adam
4e38df059e
Remove numpy dependency
According to `setup.py`, the only dependencies are torch and einops, but `bert_padding.py` requires `numpy` only to multiply the elements of a `torch.Size` object. This change allows FlashAttention to be used without numpy.
2022-10-06 19:17:15 +02:00
Tri Dao
13403e8115
Relax assert to allow both bf16 and fp16
2022-09-11 12:09:43 -07:00
eric-tc-wong
b410d14f28
Update flash_attention.py
Recast query and key after rotary_emb()
2022-09-06 17:29:49 -04:00
Tri Dao
19d1261025
Add back need_weights in FlashMHA
2022-08-09 10:14:10 -07:00
Tri Dao
6cc7342575
Support index_first_axis with more than 2 dimensions
2022-08-05 09:48:16 -07:00
Tri Dao
713ea302d7
Allow headdim 128 in FlashMHA interface
2022-08-05 09:47:22 -07:00
Tri Dao
a5559a0e75
Do P * dP (pointwise) in the bwd in fp32 instead of fp16
2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af
Implement cross attention
2022-07-03 17:48:12 -07:00