Tri Dao | 46fd2a20b2 | Support all head dims that are multiples of 8, up to 128 | 2022-10-24 16:04:21 -07:00
Tri Dao | a5a8806d1a | Split bwd on the seqlen_q dimension | 2022-10-23 11:35:15 -07:00
Tri Dao | a44f48df5a | Split fwd on the seqlen_q dimension | 2022-10-21 12:04:27 -07:00
Tri Dao | 1aa6d7d9b6 | Rework dropout to decouple forward and backward: they don't have to have the same block size, number of threads, etc. | 2022-10-21 12:04:27 -07:00
Tri Dao | 1b9facacc3 | Fix QKV interface to allocate output in Python | 2022-10-14 03:33:41 -07:00
Tri Dao | 5badfb7848 | Implement attention kernel that splits the batch into two | 2022-10-13 20:49:02 -07:00
Tri Dao | a5559a0e75 | Do P * dP (pointwise) in the bwd in fp32 instead of fp16 | 2022-07-03 17:52:05 -07:00
Tri Dao | 6c3a8c65af | Implement cross attention | 2022-07-03 17:48:12 -07:00
Tri Dao | 5a61cb7729 | Rename src -> flash_attn | 2022-06-01 18:50:26 -07:00