Tri Dao | 9b0bc97872 | Fix race condition in Triton fwd | 2022-10-31 14:34:57 -07:00
Tri Dao | 4f81aff46e | Add debug_barrier for all headdims in Triton bwd | 2022-10-31 01:25:02 -07:00
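Both fixes above involve Triton's tl.debug_barrier(), which synchronizes all threads of one program instance. A toy sketch of the pattern, not the attention kernel itself: when a buffer written in one phase is re-read in a later phase, the barrier keeps the two phases from overlapping.

    import triton
    import triton.language as tl

    @triton.jit
    def _two_phase_kernel(x_ptr, scratch_ptr, out_ptr, N, BLOCK: tl.constexpr):
        # Illustrative kernel; names are made up for this sketch.
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        x = tl.load(x_ptr + offs, mask=mask, other=0.0)
        # Phase 1: write an intermediate result.
        tl.store(scratch_ptr + offs, x * 2.0, mask=mask)
        # Synchronize every thread in the program before phase 2 reads the
        # intermediate back; without the barrier the read can be reordered
        # ahead of (or overlapped with) the store, i.e. a race condition.
        tl.debug_barrier()
        y = tl.load(scratch_ptr + offs, mask=mask, other=0.0)
        tl.store(out_ptr + offs, y + 1.0, mask=mask)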
Tri Dao | e78d509c64 | [WIP] Support all head dimensions up to 128 in the Triton bwd (WIP because there seem to be race conditions for head dimensions other than 16, 32, 64, 128) | 2022-10-31 00:46:22 -07:00
Tri Dao | 008951f1d9 | Support all head dimensions up to 128 in the Triton fwd | 2022-10-30 22:10:48 -07:00
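A plausible sketch of how a fwd kernel supports arbitrary head dimensions up to 128 (the helper and its names are illustrative, not the repo's code): the head dimension is padded up to a compile-time power-of-two block, and the tail is masked on load so out-of-range lanes contribute zeros to tl.dot.

    import triton
    import triton.language as tl

    @triton.jit
    def _load_q_tile(q_ptr, stride_qm, start_m, seqlen_q, headdim,
                     BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr):
        # BLOCK_HEADDIM is the next power of two >= headdim (e.g. 128 for 96).
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_d = tl.arange(0, BLOCK_HEADDIM)
        ptrs = q_ptr + offs_m[:, None] * stride_qm + offs_d[None, :]
        # Rows past seqlen_q and columns past headdim read as zero, so
        # matmuls over the padded tile still give the exact result.
        mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
        return tl.load(ptrs, mask=mask, other=0.0)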
Tri Dao | b910bf14c1 | Support arbitrary seqlens (both q & k) in Triton bwd | 2022-10-30 21:50:53 -07:00
Tri Dao | dc55469355 | Support arbitrary seqlen_k in Triton bwd | 2022-10-30 21:26:26 -07:00
Tri Dao | d11341fd1a | Fix Triton fwd to support seqlens that are not multiples of 128 | 2022-10-30 19:05:47 -07:00
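The three seqlen commits above all come down to the same boundary handling; a hedged sketch (an illustrative helper, not the repo's code): key positions past seqlen_k are loaded as zero, and their scores are forced to -inf so the softmax assigns them zero weight.

    import triton
    import triton.language as tl

    @triton.jit
    def _masked_scores(q, k_ptr, stride_kn, start_n, seqlen_k, headdim,
                       BLOCK_N: tl.constexpr, BLOCK_HEADDIM: tl.constexpr):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_HEADDIM)
        k = tl.load(k_ptr + offs_n[:, None] * stride_kn + offs_d[None, :],
                    mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0)
        qk = tl.dot(q, tl.trans(k))
        # Zero-filled keys would still receive softmax weight exp(0 - max),
        # so out-of-range positions must be pushed to -inf, not merely zeroed.
        return tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))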
Tri Dao | b0c0db81f6 | Implement FlashAttention in Triton | 2022-10-30 18:09:11 -07:00
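For reference, the algorithm this kernel tiles is the streaming (online) softmax recurrence; below is a minimal eager-PyTorch rendition for a single head, with no masking or dropout. It is exposition of the recurrence, not the kernel.

    import torch

    def flash_attn_reference(q, k, v, block_n=128):
        # q: (M, d); k, v: (N, d). Process k/v in tiles of block_n rows.
        M, d = q.shape
        scale = d ** -0.5
        acc = torch.zeros(M, d, dtype=torch.float32, device=q.device)
        m_i = torch.full((M, 1), float("-inf"), dtype=torch.float32, device=q.device)
        l_i = torch.zeros((M, 1), dtype=torch.float32, device=q.device)
        for start in range(0, k.shape[0], block_n):
            kb = k[start:start + block_n].float()
            vb = v[start:start + block_n].float()
            s = (q.float() @ kb.T) * scale                       # tile of scores
            m_new = torch.maximum(m_i, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)                             # tile numerator
            alpha = torch.exp(m_i - m_new)                       # rescale old partials
            l_i = l_i * alpha + p.sum(dim=-1, keepdim=True)
            acc = acc * alpha + p @ vb
            m_i = m_new
        return (acc / l_i).to(q.dtype)

    # Agrees with the naive computation:
    q, k, v = (torch.randn(64, 32) for _ in range(3))
    ref = torch.softmax((q @ k.T) * 32 ** -0.5, dim=-1) @ v
    assert torch.allclose(flash_attn_reference(q, k, v, block_n=16), ref, atol=1e-5)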
Tri Dao | 46fd2a20b2 | Support all head dims that are multiples of 8, up to 128 | 2022-10-24 16:04:21 -07:00
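With the kernel accepting any multiple of 8, remaining head dims can be handled by padding at the Python wrapper level; a hedged sketch (the function name is made up): pad the last dimension so fp16 loads stay aligned, then slice the output back to the original head dim.

    import torch
    import torch.nn.functional as F

    def pad_headdim(q, k, v, multiple=8):
        # Pad the head dimension up to a multiple of 8; the caller slices
        # the output back with out[..., :d].
        d = q.shape[-1]
        pad = (-d) % multiple
        if pad:
            q, k, v = (F.pad(t, (0, pad)) for t in (q, k, v))
        return q, k, v, d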
Tri Dao | a5a8806d1a | Split bwd on the seqlen_q dimension | 2022-10-23 11:35:15 -07:00
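Conceptually (this is an assumption about the parallelization, rendered in eager PyTorch rather than as a kernel), splitting the backward over seqlen_q works because each q-block owns its slice of dq outright, while its contributions to dk and dv must be accumulated, a `+=` here and atomics or a reduction on the GPU.

    import torch

    def attn_bwd_split_q(q, k, v, do, block_m=128):
        # q, do: (M, d); k, v: (N, d). Single head, no dropout/masking.
        scale = q.shape[-1] ** -0.5
        dq = torch.empty_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)
        for start in range(0, q.shape[0], block_m):
            qb, dob = q[start:start + block_m], do[start:start + block_m]
            p = torch.softmax((qb @ k.T) * scale, dim=-1)
            dv += p.T @ dob                                   # accumulated across blocks
            dp = dob @ v.T
            ds = p * (dp - (p * dp).sum(-1, keepdim=True))    # softmax backward
            dq[start:start + block_m] = (ds @ k) * scale      # private to this block
            dk += (ds.T @ qb) * scale
        return dq, dk, dv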
Tri Dao | 1aa6d7d9b6 | Rework dropout to decouple forward and backward (they no longer need to share the same block size, number of threads, etc.) | 2022-10-21 12:04:27 -07:00
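One way to read the decoupling (an assumption about the mechanism, illustrated here with Triton's counter-based tl.rand): key the dropout randomness on each element's global offset rather than on thread or block indices, so forward and backward regenerate identical masks even when they tile the matrix differently, and no mask tensor has to be stored.

    import triton
    import triton.language as tl

    @triton.jit
    def _dropout_keep_mask(seed, row_start, col_start, ncols, p_drop,
                           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        # Illustrative helper, not the repo's code.
        rows = row_start + tl.arange(0, BLOCK_M)
        cols = col_start + tl.arange(0, BLOCK_N)
        # The global element index is independent of the tiling, so any block
        # shape in fwd or bwd reproduces the same per-element random draw.
        offs = rows[:, None] * ncols + cols[None, :]
        return tl.rand(seed, offs) >= p_drop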
Tri Dao | 52fb4b729b | Fix #54: set device for multi-GPU case | 2022-10-16 12:51:26 -07:00
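The usual shape of this fix (a sketch, not the repo's exact code): CUDA kernels launch on the current device, so when the input lives on a GPU other than cuda:0 the launch has to be wrapped in a device context.

    import torch

    def run_on_input_device(kernel, q, *args):
        # Without this guard, a tensor on cuda:1 gets processed by a kernel
        # launched on the default current device (cuda:0) and fails.
        with torch.cuda.device(q.device):
            return kernel(q, *args)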
Tri Dao | 5badfb7848 | Implement attention kernel that splits the batch into two | 2022-10-13 20:49:02 -07:00
Tri Dao | 0c01568daf | Only run backward test for d=128 on A100 | 2022-10-04 18:06:08 -07:00
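A sketch of how such a gate might look in pytest (the test layout and the shared-memory rationale are assumptions): skip the d=128 backward cases unless the GPU reports compute capability 8.0, i.e. an A100.

    import pytest
    import torch

    # Assumption: d=128 backward needs more shared memory than pre-A100 GPUs have.
    requires_a100 = pytest.mark.skipif(
        not torch.cuda.is_available()
        or torch.cuda.get_device_capability() != (8, 0),
        reason="backward for d=128 is only run on A100 (sm80)",
    )

    @requires_a100
    @pytest.mark.parametrize("d", [128])
    def test_backward_d128(d):
        ...  # placeholder body; the real tests live in the repo's test suite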
Tri Dao | 2ed471ecc4 | Add tests for numerical error | 2022-07-22 17:54:09 -04:00
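Numerical-error tests in this family typically follow one pattern (the exact tolerances below are assumptions): compare the kernel against an fp32 reference, and require its error to stay within a small multiple of the error a plain fp16 implementation already makes against that same reference.

    import torch

    def attention_ref(q, k, v, upcast=True):
        if upcast:
            q, k, v = q.float(), k.float(), v.float()
        p = torch.softmax(q @ k.transpose(-1, -2) * q.shape[-1] ** -0.5, dim=-1)
        return p @ v

    def check_numerical_error(out_flash, q, k, v):
        out_ref = attention_ref(q, k, v, upcast=True)           # fp32 "ground truth"
        out_pt = attention_ref(q, k, v, upcast=False).float()   # fp16 baseline error
        # The kernel may differ from fp32, but not by much more than fp16 itself.
        assert (out_flash.float() - out_ref).abs().max().item() <= \
            2 * (out_pt - out_ref).abs().max().item() + 1e-5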