Phil Wang
|
b0eac3297f
|
allow for uploading to pypi
|
2022-11-15 13:26:55 -08:00 |
|
Tri Dao
|
2e33fc8e36
|
Add GPT and ViT models
|
2022-11-13 22:30:23 -08:00 |
|
Tri Dao
|
d4b320b31f
|
Add MLP, MHA, Block, Embedding modules
|
2022-11-13 22:06:44 -08:00 |
|
Tri Dao
|
fa6d1ce44f
|
Add fused_dense and dropout_add_layernorm CUDA extensions
|
2022-11-13 21:59:20 -08:00 |
|
Tri Dao
|
b92f2c3b67
|
Link to Colossal-AI's stable diffusion in usage.md
|
2022-11-13 20:49:05 -08:00 |
|
Tri Dao
|
343492ec30
|
Make nccl operations async in CrossEntropyLossParallel
|
2022-11-13 17:27:26 -08:00 |
|
Tri Dao
|
3dda4f76de
|
Update README
|
2022-11-13 16:52:40 -08:00 |
|
Tri Dao
|
79160a69a9
|
Add a page on where FlashAttention is being used
|
2022-11-13 16:40:18 -08:00 |
|
Tri Dao
|
a8fec99a9a
|
Skip flash_attn_split test
|
2022-11-13 12:27:48 -08:00 |
|
Tri Dao
|
9d3116addf
|
Don't enforce bitwise consistency for dq in race condition test
Since we could be parallelizing over seqlen_k
|
2022-11-13 12:21:51 -08:00 |
|
Tri Dao
|
7c9953815a
|
Add fused cross entropy loss
|
2022-11-12 21:58:41 -08:00 |
|
Tri Dao
|
55797f32c9
|
Remove RotaryEmbedding from FlashAttention module
To avoid import error if one doesn't have rotary_emb installed
|
2022-11-10 11:54:36 -08:00 |
|
Tri Dao
|
6998e0ecdb
|
Fix out-of-bound memory read
|
2022-11-09 09:34:14 -08:00 |
|
Tri Dao
|
908a5b2244
|
Set num_warps=4 for headdim=64 in Triton fw (h/t Michael Benesty)
|
2022-11-07 08:58:16 -08:00 |
|
Tri Dao
|
7479757191
|
Fix pipelining bug in Triton bwd with bias_type=matrix
|
2022-11-06 11:50:35 -08:00 |
|
Tri Dao
|
557781933d
|
Parallelize CUDA bwd along seqlen_k instead of seqlen_q
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
|
2022-11-05 16:26:17 -07:00 |
|
Tri Dao
|
ca81f32e04
|
Implement rotary embedding in CUDA
|
2022-11-04 22:42:01 -07:00 |
|
Tri Dao
|
62025e1aff
|
Fix more race condition in Triton bwd when there's bias
|
2022-11-04 12:53:09 -07:00 |
|
Tri Dao
|
ff78ea4123
|
Fix race condition in Triton bwd when there's bias
|
2022-11-04 11:20:27 -07:00 |
|
Tri Dao
|
86862cfd7b
|
Implement attention bias for Triton version
|
2022-11-04 10:33:54 -07:00 |
|
Tri Dao
|
470010f59b
|
Fix race condition for Triton bwd for headdim 48 and 96
|
2022-11-03 15:52:40 -07:00 |
|
Tri Dao
|
aacc10fbab
|
Fix race condition in Triton bwd for non-po2 headdims
|
2022-11-02 07:32:54 -07:00 |
|
Tri Dao
|
1fb12afdfb
|
Avoid memcpy in the Triton bwd
|
2022-11-01 15:06:45 -07:00 |
|
Tri Dao
|
731f154de3
|
Fix race conditions in the Triton bwd for headdim=64
|
2022-11-01 15:05:55 -07:00 |
|
Tri Dao
|
9b0bc97872
|
Fix race condition in Triton fwd
|
2022-10-31 14:34:57 -07:00 |
|
Tri Dao
|
215930bce3
|
Fix EVEN_M & EVEN_HEADDIM for headdim=40 in Triton bwd
|
2022-10-31 01:41:49 -07:00 |
|
Tri Dao
|
4f81aff46e
|
Add debug_barrier for all headdims in Triton bwd
|
2022-10-31 01:25:02 -07:00 |
|
Tri Dao
|
bedcbd6a71
|
Disable some autotune configs that give wrong results in Triton bwd
|
2022-10-31 01:05:51 -07:00 |
|
Tri Dao
|
e78d509c64
|
[WIP] Support all head dimensions up to 128 in the Triton bwd
WIP because there seems to be some race conditions for head dimensions other
than 16, 32, 64, 128.
|
2022-10-31 00:46:22 -07:00 |
|
Tri Dao
|
008951f1d9
|
Support all head dimensions up to 128 in the Triton fwd
|
2022-10-30 22:10:48 -07:00 |
|
Tri Dao
|
b910bf14c1
|
Support arbitrary seqlens (both q & k) in Triton bwd
|
2022-10-30 21:50:53 -07:00 |
|
Tri Dao
|
dc55469355
|
Support arbitrary seqlen_k in Triton bwd
|
2022-10-30 21:26:26 -07:00 |
|
Tri Dao
|
d11341fd1a
|
Fix Triton fwd to support seqlen not multiples of 128
|
2022-10-30 19:05:47 -07:00 |
|
Tri Dao
|
b0c0db81f6
|
Implement FlashAttention in Triton
|
2022-10-30 18:09:11 -07:00 |
|
Tri Dao
|
c422fee377
|
Get rid of o_rows_are_valid since we don't have headdim=16 anymore
|
2022-10-24 17:29:36 -07:00 |
|
Tri Dao
|
46fd2a20b2
|
Support all head dims that are multiples of 8, up to 128
|
2022-10-24 16:04:21 -07:00 |
|
Tri Dao
|
97e13de2b4
|
Cast q.get_device() to char to avoid compiler warning (narrowing)
|
2022-10-24 15:59:49 -07:00 |
|
Tri Dao
|
ed553e9238
|
Add Megatron attention implementation for benchmarking
|
2022-10-23 23:04:16 -07:00 |
|
Tri Dao
|
50ca23488d
|
Add Triton implementation for benchmarking
|
2022-10-23 17:25:56 -07:00 |
|
Tri Dao
|
9e92a1f2d2
|
Attempt to use atomicCAS to replace atomicAdd(bfloat16)
|
2022-10-23 16:22:43 -07:00 |
|
Tri Dao
|
6731855b1f
|
Merge pull request #61 from robotcator/workflow
build wheel and upload to release
|
2022-10-23 12:52:51 -07:00 |
|
Tri Dao
|
fb88e5e4b3
|
Move benchmark utils, support AMP
|
2022-10-23 12:50:00 -07:00 |
|
Tri Dao
|
a5a8806d1a
|
Split bwd on the seqlen_q dimension
|
2022-10-23 11:35:15 -07:00 |
|
Tri Dao
|
871db47941
|
Don't need to run configure for the forward pass
|
2022-10-21 18:22:27 -07:00 |
|
Tri Dao
|
7fc39832e2
|
Use block_size=128 for headdim=128 on SM80
Previously we were using block_size=256.
|
2022-10-21 13:19:54 -07:00 |
|
Tri Dao
|
a44f48df5a
|
Split fwd on the seqlen_q dimension
|
2022-10-21 12:04:27 -07:00 |
|
Tri Dao
|
1aa6d7d9b6
|
Rework dropout to decouple forward and backward
They don't have to have the same block size, number of threads, etc.
|
2022-10-21 12:04:27 -07:00 |
|
Tri Dao
|
1d0b41be3b
|
Merge pull request #60 from 201419/patch-1
fix typo in function mha_fwd
|
2022-10-17 09:38:48 -07:00 |
|
robotcator
|
35d589fa81
|
Merge branch 'main' of github.com:robotcator/flash-attention into workflow
|
2022-10-17 17:41:37 +08:00 |
|
robotcator
|
10d0745966
|
using tag trigger rather than push trigger
|
2022-10-17 17:39:08 +08:00 |
|