Commit Graph

13 Commits

Author SHA1 Message Date
Tri Dao
1aa6d7d9b6 Rework dropout to decouple forward and backward
The forward and backward kernels no longer need to share the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
Tri Dao
1b9facacc3 Fix QKV interface to allocate output in Python 2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848 Implement attention kernel that splits the batch into two 2022-10-13 20:49:02 -07:00
Antoine Adam
4e38df059e remove numpy dependency
According to the `setup.py` file, the only dependencies are torch and einops. But `bert_padding.py` imports `numpy` only to multiply the elements of a `torch.Size` object. This change allows FlashAttention to be used without numpy.
2022-10-06 19:17:15 +02:00
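The change described above amounts to replacing a `numpy` product over a `torch.Size` with plain Python. A minimal sketch of the idea (the variable names are illustrative, not the repository's actual code):

```python
# Illustrative sketch: computing the product of a torch.Size without
# numpy, as the commit describes. Not the repo's actual code.
import torch

shape = torch.Size([8, 128, 64])

# numpy-based version being removed:
#   import numpy as np
#   total = np.prod(shape)

# Pure-Python replacement with the same result
# (math.prod from the standard library would also work on 3.8+):
total = 1
for d in shape:
    total *= d

print(total)  # 65536
```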
Tri Dao
13403e8115 Relax assert to allow both bf16 and fp16 2022-09-11 12:09:43 -07:00
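Relaxing a dtype assert to accept both half-precision formats can be sketched as follows; the function and variable names here are assumptions for illustration, not the repository's actual interface:

```python
# Hedged sketch of relaxing a dtype assert to accept bf16 as well as
# fp16. `check_dtype` is a hypothetical helper, not the repo's code.
import torch

def check_dtype(qkv: torch.Tensor) -> None:
    # Before: assert qkv.dtype == torch.float16
    # After: accept either half-precision dtype.
    assert qkv.dtype in (torch.float16, torch.bfloat16), (
        f"unsupported dtype {qkv.dtype}"
    )

check_dtype(torch.zeros(2, 3, dtype=torch.bfloat16))  # passes
```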
eric-tc-wong
b410d14f28 Update flash_attention.py
Recasting query and key after rotary_emb()
2022-09-06 17:29:49 -04:00
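The recast described in that commit addresses a common pattern: a rotary-embedding helper that computes in fp32 returns upcast tensors, which a half-precision-only attention kernel would reject. A minimal sketch, where `apply_rotary` is a hypothetical stand-in for the library's `rotary_emb()`:

```python
# Illustrative sketch of recasting query and key after a rotary
# embedding. `apply_rotary` is a hypothetical stand-in that upcasts
# to fp32, as fp32 trig tables often do; not the repo's rotary_emb.
import torch

def apply_rotary(x: torch.Tensor) -> torch.Tensor:
    return x.float()  # stand-in: returns fp32 regardless of input dtype

q = torch.randn(2, 4, 8, dtype=torch.float16)
k = torch.randn(2, 4, 8, dtype=torch.float16)

dtype = q.dtype
q, k = apply_rotary(q), apply_rotary(k)
# Recast so the fp16/bf16-only attention kernel accepts the tensors.
q, k = q.to(dtype), k.to(dtype)
assert q.dtype == torch.float16 and k.dtype == torch.float16
```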
Tri Dao
19d1261025 Add back need_weights in FlashMHA 2022-08-09 10:14:10 -07:00
Tri Dao
6cc7342575 Support index_first_axis with more than 2 dimensions 2022-08-05 09:48:16 -07:00
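What "more than 2 dimensions" enables can be pictured as gathering along the first axis of a tensor that also carries head and feature dimensions. The sketch below uses a simplified stand-in named `index_first_axis`, not the repository's autograd-aware implementation:

```python
# Hypothetical sketch: gathering rows along the first axis of a 3-D
# tensor (e.g. unpadded tokens with head and feature dims). This is a
# simplified stand-in, not the repo's index_first_axis.
import torch

def index_first_axis(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Works for any number of trailing dimensions, not just 2-D input.
    return x[indices]

tokens = torch.randn(10, 3, 4)    # (total_tokens, nheads, headdim)
keep = torch.tensor([0, 2, 5])    # indices of tokens to keep
out = index_first_axis(tokens, keep)
print(out.shape)  # torch.Size([3, 3, 4])
```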
Tri Dao
713ea302d7 Allow headdim 128 in FlashMHA interface 2022-08-05 09:47:22 -07:00
Tri Dao
a5559a0e75 Do P * dP (pointwise) in the bwd in fp32 instead of fp16 2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af Implement cross attention 2022-07-03 17:48:12 -07:00
Gustaf
af4a9ce024 Add missing __init__.py 2022-07-03 02:04:55 -04:00
Tri Dao
5a61cb7729 Rename src -> flash_attn 2022-06-01 18:50:26 -07:00