Commit Graph

10 Commits

Author SHA1 Message Date
Vik Paruchuri
3165398074 Remove unused kwargs in flashattention 2023-03-15 10:36:19 -07:00
Kiarash Jamali
41cb909741
Change default dropout value in documentation
Documentation says default is 0.1, but the code has attention_dropout default at 0.0
2023-01-13 10:50:07 +00:00
Tri Dao
5fb6df0e04 Implement BERT 2022-12-18 21:47:27 -08:00
Tri Dao
55797f32c9 Remove RotaryEmbedding from FlashAttention module
To avoid import error if one doesn't have rotary_emb installed
2022-11-10 11:54:36 -08:00
Tri Dao
13403e8115 Relax assert to allow both bf16 and fp16 2022-09-11 12:09:43 -07:00
eric-tc-wong
b410d14f28
Update flash_attention.py
Recasting query and key after rotary_emb()
2022-09-06 17:29:49 -04:00
Tri Dao
19d1261025 Add back need_weights in FlashMHA 2022-08-09 10:14:10 -07:00
Tri Dao
713ea302d7 Allow headdim 128 in FlashMHA interface 2022-08-05 09:47:22 -07:00
Tri Dao
6c3a8c65af Implement cross attention 2022-07-03 17:48:12 -07:00
Tri Dao
5a61cb7729 Rename src -> flash_attn 2022-06-01 18:50:26 -07:00