| Author | Commit | Message | Date |
|---|---|---|---|
| Tri Dao | 19d1261025 | Add back need_weights in FlashMHA | 2022-08-09 10:14:10 -07:00 |
| Tri Dao | 6cc7342575 | Support index_first_axis with more than 2 dimensions | 2022-08-05 09:48:16 -07:00 |
| Tri Dao | 713ea302d7 | Allow headdim 128 in FlashMHA interface | 2022-08-05 09:47:22 -07:00 |
| Tri Dao | a5559a0e75 | Do P * dP (pointwise) in the bwd in fp32 instead of fp16 | 2022-07-03 17:52:05 -07:00 |
| Tri Dao | 6c3a8c65af | Implement cross attention | 2022-07-03 17:48:12 -07:00 |
| Gustaf | af4a9ce024 | Add missing __init__.py | 2022-07-03 02:04:55 -04:00 |
| Tri Dao | 5a61cb7729 | Rename src -> flash_attn | 2022-06-01 18:50:26 -07:00 |