Commit Graph

16 Commits

Author   SHA1        Message  Date
Tri Dao  ff78ea4123  Fix race condition in Triton bwd when there's bias  2022-11-04 11:20:27 -07:00
Tri Dao  86862cfd7b  Implement attention bias for Triton version  2022-11-04 10:33:54 -07:00
Tri Dao  470010f59b  Fix race condition for Triton bwd for headdim 48 and 96  2022-11-03 15:52:40 -07:00
Tri Dao  aacc10fbab  Fix race condition in Triton bwd for non-po2 headdims  2022-11-02 07:32:54 -07:00
Tri Dao  1fb12afdfb  Avoid memcpy in the Triton bwd  2022-11-01 15:06:45 -07:00
Tri Dao  731f154de3  Fix race conditions in the Triton bwd for headdim=64  2022-11-01 15:05:55 -07:00
Tri Dao  9b0bc97872  Fix race condition in Triton fwd  2022-10-31 14:34:57 -07:00
Tri Dao  215930bce3  Fix EVEN_M & EVEN_HEADDIM for headdim=40 in Triton bwd  2022-10-31 01:41:49 -07:00
Tri Dao  4f81aff46e  Add debug_barrier for all headdims in Triton bwd  2022-10-31 01:25:02 -07:00
Tri Dao  bedcbd6a71  Disable some autotune configs that give wrong results in Triton bwd  2022-10-31 01:05:51 -07:00
Tri Dao  e78d509c64  [WIP] Support all head dimensions up to 128 in the Triton bwd  2022-10-31 00:46:22 -07:00
                     WIP because there seem to be some race conditions for head dimensions other than 16, 32, 64, 128.
Tri Dao  008951f1d9  Support all head dimensions up to 128 in the Triton fwd  2022-10-30 22:10:48 -07:00
Tri Dao  b910bf14c1  Support arbitrary seqlens (both q & k) in Triton bwd  2022-10-30 21:50:53 -07:00
Tri Dao  dc55469355  Support arbitrary seqlen_k in Triton bwd  2022-10-30 21:26:26 -07:00
Tri Dao  d11341fd1a  Fix Triton fwd to support seqlen not multiples of 128  2022-10-30 19:05:47 -07:00
Tri Dao  b0c0db81f6  Implement FlashAttention in Triton  2022-10-30 18:09:11 -07:00