Commit Graph

124 Commits

Author SHA1 Message Date
Phil Wang
b0eac3297f allow for uploading to pypi 2022-11-15 13:26:55 -08:00
Tri Dao
2e33fc8e36 Add GPT and ViT models 2022-11-13 22:30:23 -08:00
Tri Dao
d4b320b31f Add MLP, MHA, Block, Embedding modules 2022-11-13 22:06:44 -08:00
Tri Dao
fa6d1ce44f Add fused_dense and dropout_add_layernorm CUDA extensions 2022-11-13 21:59:20 -08:00
Tri Dao
b92f2c3b67 Link to Colossal-AI's stable diffusion in usage.md 2022-11-13 20:49:05 -08:00
Tri Dao
343492ec30 Make nccl operations async in CrossEntropyLossParallel 2022-11-13 17:27:26 -08:00
Tri Dao
3dda4f76de Update README 2022-11-13 16:52:40 -08:00
Tri Dao
79160a69a9 Add a page on where FlashAttention is being used 2022-11-13 16:40:18 -08:00
Tri Dao
a8fec99a9a Skip flash_attn_split test 2022-11-13 12:27:48 -08:00
Tri Dao
9d3116addf Don't enforce bitwise consistency for dq in race condition test
Since we could be parallelizing over seqlen_k
2022-11-13 12:21:51 -08:00
Tri Dao
7c9953815a Add fused cross entropy loss 2022-11-12 21:58:41 -08:00
Tri Dao
55797f32c9 Remove RotaryEmbedding from FlashAttention module
To avoid import error if one doesn't have rotary_emb installed
2022-11-10 11:54:36 -08:00
Tri Dao
6998e0ecdb Fix out-of-bound memory read 2022-11-09 09:34:14 -08:00
Tri Dao
908a5b2244 Set num_warps=4 for headdim=64 in Triton fw (h/t Michael Benesty) 2022-11-07 08:58:16 -08:00
Tri Dao
7479757191 Fix pipelining bug in Triton bwd with bias_type=matrix 2022-11-06 11:50:35 -08:00
Tri Dao
557781933d Parallelize CUDA bwd along seqlen_k instead of seqlen_q
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
Tri Dao
ca81f32e04 Implement rotary embedding in CUDA 2022-11-04 22:42:01 -07:00
Tri Dao
62025e1aff Fix more race condition in Triton bwd when there's bias 2022-11-04 12:53:09 -07:00
Tri Dao
ff78ea4123 Fix race condition in Triton bwd when there's bias 2022-11-04 11:20:27 -07:00
Tri Dao
86862cfd7b Implement attention bias for Triton version 2022-11-04 10:33:54 -07:00
Tri Dao
470010f59b Fix race condition for Triton bwd for headdim 48 and 96 2022-11-03 15:52:40 -07:00
Tri Dao
aacc10fbab Fix race condition in Triton bwd for non-po2 headdims 2022-11-02 07:32:54 -07:00
Tri Dao
1fb12afdfb Avoid memcpy in the Triton bwd 2022-11-01 15:06:45 -07:00
Tri Dao
731f154de3 Fix race conditions in the Triton bwd for headdim=64 2022-11-01 15:05:55 -07:00
Tri Dao
9b0bc97872 Fix race condition in Triton fwd 2022-10-31 14:34:57 -07:00
Tri Dao
215930bce3 Fix EVEN_M & EVEN_HEADDIM for headdim=40 in Triton bwd 2022-10-31 01:41:49 -07:00
Tri Dao
4f81aff46e Add debug_barrier for all headdims in Triton bwd 2022-10-31 01:25:02 -07:00
Tri Dao
bedcbd6a71 Disable some autotune configs that give wrong results in Triton bwd 2022-10-31 01:05:51 -07:00
Tri Dao
e78d509c64 [WIP] Support all head dimensions up to 128 in the Triton bwd
WIP because there seems to be some race conditions for head dimensions other
than 16, 32, 64, 128.
2022-10-31 00:46:22 -07:00
Tri Dao
008951f1d9 Support all head dimensions up to 128 in the Triton fwd 2022-10-30 22:10:48 -07:00
Tri Dao
b910bf14c1 Support arbitrary seqlens (both q & k) in Triton bwd 2022-10-30 21:50:53 -07:00
Tri Dao
dc55469355 Support arbitrary seqlen_k in Triton bwd 2022-10-30 21:26:26 -07:00
Tri Dao
d11341fd1a Fix Triton fwd to support seqlen not multiples of 128 2022-10-30 19:05:47 -07:00
Tri Dao
b0c0db81f6 Implement FlashAttention in Triton 2022-10-30 18:09:11 -07:00
Tri Dao
c422fee377 Get rid of o_rows_are_valid since we don't have headdim=16 anymore 2022-10-24 17:29:36 -07:00
Tri Dao
46fd2a20b2 Support all head dims that are multiples of 8, up to 128 2022-10-24 16:04:21 -07:00
Tri Dao
97e13de2b4 Cast q.get_device() to char to avoid compiler warning (narrowing) 2022-10-24 15:59:49 -07:00
Tri Dao
ed553e9238 Add Megatron attention implementation for benchmarking 2022-10-23 23:04:16 -07:00
Tri Dao
50ca23488d Add Triton implementation for benchmarking 2022-10-23 17:25:56 -07:00
Tri Dao
9e92a1f2d2 Attempt to use atomicCAS to replace atomicAdd(bfloat16) 2022-10-23 16:22:43 -07:00
Tri Dao
6731855b1f
Merge pull request #61 from robotcator/workflow
build wheel and upload to release
2022-10-23 12:52:51 -07:00
Tri Dao
fb88e5e4b3 Move benchmark utils, support AMP 2022-10-23 12:50:00 -07:00
Tri Dao
a5a8806d1a Split bwd on the seqlen_q dimension 2022-10-23 11:35:15 -07:00
Tri Dao
871db47941 Don't need to run configure for the forward pass 2022-10-21 18:22:27 -07:00
Tri Dao
7fc39832e2 Use block_size=128 for headdim=128 on SM80
Previously we were using block_size=256.
2022-10-21 13:19:54 -07:00
Tri Dao
a44f48df5a Split fwd on the seqlen_q dimension 2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6 Rework dropout to decouple forward and backward
They don't have to have the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
Tri Dao
1d0b41be3b
Merge pull request #60 from 201419/patch-1
fix typo in function mha_fwd
2022-10-17 09:38:48 -07:00
robotcator
35d589fa81 Merge branch 'main' of github.com:robotcator/flash-attention into workflow 2022-10-17 17:41:37 +08:00
robotcator
10d0745966 using tag trigger rather than push trigger 2022-10-17 17:39:08 +08:00