Tri Dao
008951f1d9
Support all head dimensions up to 128 in the Triton fwd
2022-10-30 22:10:48 -07:00
Tri Dao
b910bf14c1
Support arbitrary seqlens (both q & k) in Triton bwd
2022-10-30 21:50:53 -07:00
Tri Dao
dc55469355
Support arbitrary seqlen_k in Triton bwd
2022-10-30 21:26:26 -07:00
Tri Dao
d11341fd1a
Fix Triton fwd to support seqlen not multiples of 128
2022-10-30 19:05:47 -07:00
Tri Dao
b0c0db81f6
Implement FlashAttention in Triton
2022-10-30 18:09:11 -07:00
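For orientation, here is a minimal, non-tiled PyTorch sketch of the computation the Triton kernel above implements; the tensor layout, function name, and the absence of masking/dropout are assumptions for illustration, and the real kernel processes the sequence in blocks without ever materializing the full score matrix.

```python
import math
import torch

def attention_reference(q, k, v):
    """Reference softmax(Q K^T / sqrt(d)) V with shapes (batch, seqlen, nheads, headdim)."""
    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bthd,bshd->bhts", q, k) * softmax_scale
    attn = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.einsum("bhts,bshd->bthd", attn, v)
```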
Tri Dao
46fd2a20b2
Support all head dims that are multiples of 8, up to 128
2022-10-24 16:04:21 -07:00
Tri Dao
ed553e9238
Add Megatron attention implementation for benchmarking
2022-10-23 23:04:16 -07:00
Tri Dao
50ca23488d
Add Triton implementation for benchmarking
2022-10-23 17:25:56 -07:00
Tri Dao
fb88e5e4b3
Move benchmark utils, support AMP
2022-10-23 12:50:00 -07:00
Tri Dao
a5a8806d1a
Split bwd on the seqlen_q dimension
2022-10-23 11:35:15 -07:00
Tri Dao
a44f48df5a
Split fwd on the seqlen_q dimension
2022-10-21 12:04:27 -07:00
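Splitting on the seqlen_q dimension means the launch grid gains an axis over query blocks, so different chunks of the query sequence run in parallel. A toy sketch of that grid calculation, with the block size and grid layout as assumptions:

```python
import math

# Hypothetical launch-grid calculation: one program instance per
# (batch * heads, block of BLOCK_M query rows).
batch, nheads, seqlen_q, BLOCK_M = 4, 16, 1000, 128
grid = (batch * nheads, math.ceil(seqlen_q / BLOCK_M))
print(grid)  # (64, 8) -> 8 query blocks per (batch, head), processed independently
```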
Tri Dao
1aa6d7d9b6
Rework dropout to decouple forward and backward

They don't have to have the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
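One way to make the two passes independent, shown here as a hedged sketch rather than the actual kernel code: save only an RNG seed in the forward pass and regenerate the identical dropout mask in the backward pass, so neither pass constrains the other's block size or thread count.

```python
import torch

def keep_mask(shape, p_drop, seed, device="cpu"):
    # Regenerate the same Bernoulli keep-mask from a saved seed instead of
    # storing the mask, so fwd and bwd can use different launch configs.
    gen = torch.Generator(device=device).manual_seed(seed)
    return torch.rand(shape, generator=gen, device=device) >= p_drop

seed = 1234
mask_fwd = keep_mask((8, 128), 0.1, seed)
mask_bwd = keep_mask((8, 128), 0.1, seed)
assert torch.equal(mask_fwd, mask_bwd)  # identical mask, never stored between passes
```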
Tri Dao
1b9facacc3
Fix QKV interface to allocate output in Python
2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848
Implement attention kernel that splits the batch into two
2022-10-13 20:49:02 -07:00
Antoine Adam
4e38df059e
Remove numpy dependency
According to the `setup.py` file, the only dependencies are torch and einops, but `bert_padding.py` imports `numpy` solely to multiply the elements of a `torch.Size` object. This change allows FlashAttention to be used without numpy.
2022-10-06 19:17:15 +02:00
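The replacement being described can be as simple as taking the product of a `torch.Size` with the standard library instead of numpy; the exact expression used in `bert_padding.py` is an assumption here.

```python
import math
import torch

x = torch.empty(4, 16, 8)
# numpy.prod(x.shape[1:]) can become math.prod(x.shape[1:]) (or x[0].numel()).
flattened_dim = math.prod(x.shape[1:])
assert flattened_dim == 128 == x[0].numel()
```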
Tri Dao
13403e8115
Relax assert to allow both bf16 and fp16
2022-09-11 12:09:43 -07:00
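A minimal sketch of such a relaxed dtype check (the surrounding interface code and error message are assumed):

```python
import torch

def check_dtype(qkv):
    # Accept either half-precision format rather than fp16 only.
    assert qkv.dtype in (torch.float16, torch.bfloat16), \
        "FlashAttention only supports fp16 and bf16"
```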
eric-tc-wong
b410d14f28
Update flash_attention.py
Recasting query and key after rotary_emb()
2022-09-06 17:29:49 -04:00
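A hedged sketch of the described fix, assuming a rotary-embedding callable that may return upcast (e.g. fp32) tensors: cast query and key back to the input dtype before handing them to the fused attention kernel.

```python
def apply_rotary_and_recast(q, k, rotary_emb):
    # Some rotary implementations compute in fp32; recast so downstream
    # kernels still receive the original fp16/bf16 dtype.
    dtype = q.dtype
    q, k = rotary_emb(q, k)
    return q.to(dtype), k.to(dtype)
```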
Tri Dao
19d1261025
Add back need_weights in FlashMHA
2022-08-09 10:14:10 -07:00
Tri Dao
6cc7342575
Support index_first_axis with more than 2 dimensions
2022-08-05 09:48:16 -07:00
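For illustration only (the library's `index_first_axis` also defines a custom backward), supporting more than two dimensions can be done by flattening the trailing dimensions, gathering the requested rows, and restoring the shape:

```python
import torch
from einops import rearrange

def index_first_axis_sketch(x, indices):
    # x: (total, d1, d2, ...); indices: (n,) rows to select along dim 0.
    trailing = x.shape[1:]
    flat = rearrange(x, "b ... -> b (...)")
    gathered = torch.gather(flat, 0, indices[:, None].expand(-1, flat.shape[1]))
    return gathered.reshape(-1, *trailing)

x = torch.randn(10, 3, 4)
idx = torch.tensor([0, 7, 2])
assert torch.equal(index_first_axis_sketch(x, idx), x[idx])
```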
Tri Dao
713ea302d7
Allow headdim 128 in FlashMHA interface
2022-08-05 09:47:22 -07:00
Tri Dao
a5559a0e75
Do P * dP (pointwise) in the bwd in fp32 instead of fp16
2022-07-03 17:52:05 -07:00
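The motivation, sketched at the PyTorch level with assumed names (P for the post-softmax probabilities, dP for its gradient): doing the pointwise product in fp32 avoids losing the contribution of very small probabilities whose products would round to zero in fp16.

```python
import torch

P = torch.full((4, 4), 1e-5, dtype=torch.float16)   # tiny post-softmax probs
dP = torch.full((4, 4), 1e-3, dtype=torch.float16)  # small incoming gradient

prod_fp16 = P * dP                    # underflows: every entry becomes 0
prod_fp32 = P.float() * dP.float()    # keeps the ~1e-8 contributions
print(prod_fp16.abs().max().item(), prod_fp32.abs().max().item())
```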
Tri Dao
6c3a8c65af
Implement cross attention
2022-07-03 17:48:12 -07:00
Gustaf
af4a9ce024
Add missing __init__.py
2022-07-03 02:04:55 -04:00
Tri Dao
5a61cb7729
Rename src -> flash_attn
2022-06-01 18:50:26 -07:00