Tri Dao
a8fec99a9a
Skip flash_attn_split test
2022-11-13 12:27:48 -08:00
Tri Dao
9d3116addf
Don't enforce bitwise consistency for dq in race condition test
...
Since we could be parallelizing over seqlen_k, the order of the atomic adds into dq is nondeterministic.
2022-11-13 12:21:51 -08:00
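A minimal sketch of a tolerance-based check for dq (the names `dq_ref` and `dq_pt` are hypothetical stand-ins for a float32 reference gradient and a same-dtype PyTorch baseline, not the repo's test code):

```python
# Sketch: compare dq within a tolerance rather than bitwise, since the order of
# the atomic adds over seqlen_k blocks is nondeterministic and floating-point
# addition is not associative. dq_ref / dq_pt are hypothetical reference tensors.
def check_dq(dq, dq_ref, dq_pt, atol=1e-5):
    kernel_err = (dq - dq_ref).abs().max().item()
    baseline_err = (dq_pt - dq_ref).abs().max().item()
    assert kernel_err <= 2 * baseline_err + atol
```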
Tri Dao
7c9953815a
Add fused cross entropy loss
2022-11-12 21:58:41 -08:00
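The title alone doesn't say what is being fused; for orientation, this is the reference math a fused cross-entropy kernel computes in one pass over the logits (a PyTorch sketch of the loss itself, not the repo's kernel):

```python
# Reference cross entropy from logits: log-sum-exp minus the correct-class
# logit, averaged over the batch. A fused kernel computes this without
# materializing the full softmax; this sketch is just the math.
import torch

def cross_entropy_reference(logits, labels):
    # logits: (batch, vocab) float tensor, labels: (batch,) int64 tensor
    lse = torch.logsumexp(logits.float(), dim=-1)
    correct = logits.float().gather(-1, labels[:, None]).squeeze(-1)
    return (lse - correct).mean()
```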
Tri Dao
55797f32c9
Remove RotaryEmbedding from FlashAttention module
...
To avoid import error if one doesn't have rotary_emb installed
2022-11-10 11:54:36 -08:00
Tri Dao
6998e0ecdb
Fix out-of-bound memory read
2022-11-09 09:34:14 -08:00
Tri Dao
908a5b2244
Set num_warps=4 for headdim=64 in Triton fwd (h/t Michael Benesty)
2022-11-07 08:58:16 -08:00
Tri Dao
7479757191
Fix pipelining bug in Triton bwd with bias_type=matrix
2022-11-06 11:50:35 -08:00
Tri Dao
557781933d
Parallelize CUDA bwd along seqlen_k instead of seqlen_q
...
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
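A PyTorch illustration of the partitioning (not the CUDA kernel): when the backward is split over blocks of seqlen_k, each block owns its slice of dk and dv, and only dq is written by every block, hence atomic adds on dq alone. Softmax scaling is omitted for brevity; `lse` is the row-wise log-sum-exp saved by the forward pass.

```python
import torch

def attention_bwd_split_k(q, k, v, o, do, lse, block_k=128):
    # q, o, do: (seqlen_q, d); k, v: (seqlen_k, d); lse: (seqlen_q,)
    dq = torch.zeros_like(q)                      # shared across blocks -> atomicAdd on GPU
    dk, dv = torch.zeros_like(k), torch.zeros_like(v)
    delta = (do * o).sum(-1, keepdim=True)        # row-wise dO . O, reused by every block
    for start in range(0, k.shape[0], block_k):   # each iteration ~ one thread block
        ks, vs = k[start:start + block_k], v[start:start + block_k]
        p = torch.exp(q @ ks.t() - lse[:, None])  # this block's attention probabilities
        dv[start:start + block_k] = p.t() @ do    # local to the block, no atomics
        dp = do @ vs.t()
        ds = p * (dp - delta)
        dk[start:start + block_k] = ds.t() @ q    # local to the block, no atomics
        dq += ds @ ks                             # every block adds here -> atomic adds
    return dq, dk, dv
```

Splitting over seqlen_q instead would flip the situation: dq becomes block-local while both dk and dv need atomic accumulation.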
Tri Dao
ca81f32e04
Implement rotary embedding in CUDA
2022-11-04 22:42:01 -07:00
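For reference, the math such a kernel applies (a PyTorch sketch using the rotate-half convention; an illustration only, not the CUDA kernel or its exact interface):

```python
# Minimal rotary-embedding reference: rotate pairs of features by a
# position-dependent angle. Rotate-half convention assumed.
import torch

def apply_rotary(x, base=10000.0):
    # x: (seqlen, nheads, headdim) with an even headdim
    seqlen, _, d = x.shape
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    angles = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]
    x1, x2 = x[..., : d // 2], x[..., d // 2:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```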
Tri Dao
62025e1aff
Fix more race conditions in Triton bwd when there's bias
2022-11-04 12:53:09 -07:00
Tri Dao
ff78ea4123
Fix race condition in Triton bwd when there's bias
2022-11-04 11:20:27 -07:00
Tri Dao
86862cfd7b
Implement attention bias for Triton version
2022-11-04 10:33:54 -07:00
Tri Dao
470010f59b
Fix race condition for Triton bwd for headdim 48 and 96
2022-11-03 15:52:40 -07:00
Tri Dao
aacc10fbab
Fix race condition in Triton bwd for non-po2 headdims
2022-11-02 07:32:54 -07:00
Tri Dao
1fb12afdfb
Avoid memcpy in the Triton bwd
2022-11-01 15:06:45 -07:00
Tri Dao
731f154de3
Fix race conditions in the Triton bwd for headdim=64
2022-11-01 15:05:55 -07:00
Tri Dao
9b0bc97872
Fix race condition in Triton fwd
2022-10-31 14:34:57 -07:00
Tri Dao
215930bce3
Fix EVEN_M & EVEN_HEADDIM for headdim=40 in Triton bwd
2022-10-31 01:41:49 -07:00
Tri Dao
4f81aff46e
Add debug_barrier for all headdims in Triton bwd
2022-10-31 01:25:02 -07:00
Tri Dao
bedcbd6a71
Disable some autotune configs that give wrong results in Triton bwd
2022-10-31 01:05:51 -07:00
Tri Dao
e78d509c64
[WIP] Support all head dimensions up to 128 in the Triton bwd
...
WIP because there seem to be some race conditions for head dimensions other than 16, 32, 64, 128.
2022-10-31 00:46:22 -07:00
Tri Dao
008951f1d9
Support all head dimensions up to 128 in the Triton fwd
2022-10-30 22:10:48 -07:00
Tri Dao
b910bf14c1
Support arbitrary seqlens (both q & k) in Triton bwd
2022-10-30 21:50:53 -07:00
Tri Dao
dc55469355
Support arbitrary seqlen_k in Triton bwd
2022-10-30 21:26:26 -07:00
Tri Dao
d11341fd1a
Fix Triton fwd to support seqlen not multiples of 128
2022-10-30 19:05:47 -07:00
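The usual Triton idiom for sequence lengths that are not a multiple of the block size is to mask out-of-range offsets on loads (with a neutral fill value) and on stores. A toy sketch of that idiom, not the attention kernel itself:

```python
# Masked loads/stores let a block-sized kernel handle any n, not just
# multiples of BLOCK. Toy elementwise copy for illustration.
import triton
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n                                  # drop out-of-range elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)  # fill with a neutral value
    tl.store(out_ptr + offs, x, mask=mask)

# Launch example: copy_kernel[(triton.cdiv(n, 128),)](x, out, n, BLOCK=128)
```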
Tri Dao
b0c0db81f6
Implement FlashAttention in Triton
2022-10-30 18:09:11 -07:00
Tri Dao
c422fee377
Get rid of o_rows_are_valid since we don't have headdim=16 anymore
2022-10-24 17:29:36 -07:00
Tri Dao
46fd2a20b2
Support all head dims that are multiples of 8, up to 128
2022-10-24 16:04:21 -07:00
Tri Dao
97e13de2b4
Cast q.get_device() to char to avoid compiler warning (narrowing)
2022-10-24 15:59:49 -07:00
Tri Dao
ed553e9238
Add Megatron attention implementation for benchmarking
2022-10-23 23:04:16 -07:00
Tri Dao
50ca23488d
Add Triton implementation for benchmarking
2022-10-23 17:25:56 -07:00
Tri Dao
9e92a1f2d2
Attempt to use atomicCAS to replace atomicAdd(bfloat16)
2022-10-23 16:22:43 -07:00
Tri Dao
6731855b1f
Merge pull request #61 from robotcator/workflow
...
build wheel and upload to release
2022-10-23 12:52:51 -07:00
Tri Dao
fb88e5e4b3
Move benchmark utils, support AMP
2022-10-23 12:50:00 -07:00
Tri Dao
a5a8806d1a
Split bwd on the seqlen_q dimension
2022-10-23 11:35:15 -07:00
Tri Dao
871db47941
Don't need to run configure for the forward pass
2022-10-21 18:22:27 -07:00
Tri Dao
7fc39832e2
Use block_size=128 for headdim=128 on SM80
...
Previously we were using block_size=256.
2022-10-21 13:19:54 -07:00
Tri Dao
a44f48df5a
Split fwd on the seqlen_q dimension
2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6
Rework dropout to decouple forward and backward
...
The two passes no longer need to use the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
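Decoupling works when each element's dropout decision is a pure function of (seed, element offset), as with a counter-based RNG, so the backward regenerates the identical mask under any launch configuration. A toy Python illustration, with a hash as a stand-in for the Philox-style RNG a real kernel would use:

```python
# If keep/drop for element i depends only on (seed, i), any block partitioning
# reproduces the same mask; block size and thread count become free choices.
import hashlib

def keep(seed: int, idx: int, p_drop: float) -> bool:
    h = hashlib.blake2b(f"{seed}:{idx}".encode(), digest_size=8).digest()
    return int.from_bytes(h, "little") / 2**64 >= p_drop

def mask_by_blocks(seed, n, p_drop, block):
    out = [None] * n
    for start in range(0, n, block):                 # each range ~ one thread block
        for i in range(start, min(start + block, n)):
            out[i] = keep(seed, i, p_drop)
    return out

# "Forward" with block 128 and "backward" with block 96 see the same mask.
assert mask_by_blocks(0, 1000, 0.17, block=128) == mask_by_blocks(0, 1000, 0.17, block=96)
```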
Tri Dao
1d0b41be3b
Merge pull request #60 from 201419/patch-1
...
fix typo in function mha_fwd
2022-10-17 09:38:48 -07:00
robotcator
35d589fa81
Merge branch 'main' of github.com:robotcator/flash-attention into workflow
2022-10-17 17:41:37 +08:00
robotcator
10d0745966
Use tag trigger rather than push trigger
2022-10-17 17:39:08 +08:00
YangShu
ff07250e8f
fix typo in function mha_fwd
...
As the title says.
2022-10-17 16:13:47 +08:00
Tri Dao
52fb4b729b
Fix #54: set device for multi-GPU case
2022-10-16 12:51:26 -07:00
Tri Dao
1b9facacc3
Fix QKV interface to allocate output in Python
2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848
Implement attention kernel that splits the batch into two
2022-10-13 20:49:02 -07:00
Tri Dao
f515c77f25
Merge pull request #53 from robotcator/workflow
...
build wheel workflow
2022-10-09 22:26:22 -07:00
Tri Dao
8dd52b0788
Merge pull request #55 from ajfadam/main
...
remove numpy dependency
2022-10-06 10:29:38 -07:00
Antoine Adam
4e38df059e
remove numpy dependency
...
According to the `setup.py` file, the only dependencies are torch and einops. But `bert_padding.py` requires `numpy` only to multiply the elements of a `torch.Size` object. This change allows FlashAttention to be used without numpy.
2022-10-06 19:17:15 +02:00
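A sketch of the kind of substitution involved (not the exact diff): the standard library, or torch itself, can multiply the entries of a `torch.Size`.

```python
# Replacing numpy.prod over a torch.Size with math.prod (or Tensor.numel).
import math
import torch

x = torch.zeros(4, 16, 8)
n = math.prod(x.shape)    # instead of numpy.prod(x.shape)
assert n == x.numel()     # equivalent built-in on the tensor
```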
Tri Dao
88dc2040a0
Merge pull request #52 from bob80333/main
...
Make flash attention compile on Windows.
2022-10-04 21:25:37 -07:00