Commit Graph

24 Commits

Author SHA1 Message Date
Tri Dao
f1a73d0740 Run isort and black on python files 2023-08-18 14:22:11 -07:00
Tri Dao
5d079fdd7a [Triton] Fix benchmark_causal, mention Triton version 2023-03-22 00:51:16 -07:00
Tri Dao
6b5f271c6d [Triton] Avoid einops repeat by using Tensor.expand 2022-12-14 14:48:41 -08:00
Tri Dao
b8ccd20098 [Triton] Fix variable name from qkv to kv (h/t FrankZijlstra) 2022-11-22 02:07:32 -08:00
Tri Dao
908a5b2244 Set num_warps=4 for headdim=64 in Triton fw (h/t Michael Benesty) 2022-11-07 08:58:16 -08:00
Tri Dao
7479757191 Fix pipelining bug in Triton bwd with bias_type=matrix 2022-11-06 11:50:35 -08:00
Tri Dao
557781933d Parallelize CUDA bwd along seqlen_k instead of seqlen_q
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
Tri Dao
62025e1aff Fix more race condition in Triton bwd when there's bias 2022-11-04 12:53:09 -07:00
Tri Dao
ff78ea4123 Fix race condition in Triton bwd when there's bias 2022-11-04 11:20:27 -07:00
Tri Dao
86862cfd7b Implement attention bias for Triton version 2022-11-04 10:33:54 -07:00
Tri Dao
470010f59b Fix race condition for Triton bwd for headdim 48 and 96 2022-11-03 15:52:40 -07:00
Tri Dao
aacc10fbab Fix race condition in Triton bwd for non-po2 headdims 2022-11-02 07:32:54 -07:00
Tri Dao
1fb12afdfb Avoid memcpy in the Triton bwd 2022-11-01 15:06:45 -07:00
Tri Dao
731f154de3 Fix race conditions in the Triton bwd for headdim=64 2022-11-01 15:05:55 -07:00
Tri Dao
9b0bc97872 Fix race condition in Triton fwd 2022-10-31 14:34:57 -07:00
Tri Dao
215930bce3 Fix EVEN_M & EVEN_HEADDIM for headdim=40 in Triton bwd 2022-10-31 01:41:49 -07:00
Tri Dao
4f81aff46e Add debug_barrier for all headdims in Triton bwd 2022-10-31 01:25:02 -07:00
Tri Dao
bedcbd6a71 Disable some autotune configs that give wrong results in Triton bwd 2022-10-31 01:05:51 -07:00
Tri Dao
e78d509c64 [WIP] Support all head dimensions up to 128 in the Triton bwd
WIP because there seems to be some race conditions for head dimensions other
than 16, 32, 64, 128.
2022-10-31 00:46:22 -07:00
Tri Dao
008951f1d9 Support all head dimensions up to 128 in the Triton fwd 2022-10-30 22:10:48 -07:00
Tri Dao
b910bf14c1 Support arbitrary seqlens (both q & k) in Triton bwd 2022-10-30 21:50:53 -07:00
Tri Dao
dc55469355 Support arbitrary seqlen_k in Triton bwd 2022-10-30 21:26:26 -07:00
Tri Dao
d11341fd1a Fix Triton fwd to support seqlen not multiples of 128 2022-10-30 19:05:47 -07:00
Tri Dao
b0c0db81f6 Implement FlashAttention in Triton 2022-10-30 18:09:11 -07:00