Tri Dao | b0c0db81f6 | Implement FlashAttention in Triton | 2022-10-30 18:09:11 -07:00
Tri Dao | 46fd2a20b2 | Support all head dims that are multiples of 8, up to 128 | 2022-10-24 16:04:21 -07:00
Tri Dao | a5a8806d1a | Split bwd on the seqlen_q dimension | 2022-10-23 11:35:15 -07:00
Tri Dao | 1aa6d7d9b6 | Rework dropout to decouple forward and backward; they don't have to have the same block size, number of threads, etc. | 2022-10-21 12:04:27 -07:00
Tri Dao | 52fb4b729b | Fix #54: set device for multi-GPU case | 2022-10-16 12:51:26 -07:00
Tri Dao | 5badfb7848 | Implement attention kernel that splits the batch into two | 2022-10-13 20:49:02 -07:00
Tri Dao | 0c01568daf | Only run backward test for d=128 on A100 | 2022-10-04 18:06:08 -07:00
Tri Dao | 2ed471ecc4 | Add tests for numerical error | 2022-07-22 17:54:09 -04:00