Tri Dao
aacc10fbab
Fix race condition in Triton bwd for non-po2 headdims
2022-11-02 07:32:54 -07:00
Tri Dao
1fb12afdfb
Avoid memcpy in the Triton bwd
2022-11-01 15:06:45 -07:00
Tri Dao
731f154de3
Fix race conditions in the Triton bwd for headdim=64
2022-11-01 15:05:55 -07:00
Tri Dao
9b0bc97872
Fix race condition in Triton fwd
2022-10-31 14:34:57 -07:00
Tri Dao
215930bce3
Fix EVEN_M & EVEN_HEADDIM for headdim=40 in Triton bwd
2022-10-31 01:41:49 -07:00
Tri Dao
4f81aff46e
Add debug_barrier for all headdims in Triton bwd
2022-10-31 01:25:02 -07:00
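The commit above, and the race-condition fixes around it, rely on Triton's tl.debug_barrier(), which inserts a block-wide synchronization point. Below is a minimal, hypothetical sketch (not the repository's kernel; all names are invented) of the kind of read-after-write hazard such a barrier resolves: each lane writes one element and then reads an element written by a different lane.

```python
import triton
import triton.language as tl


@triton.jit
def _barrier_sketch(x_ptr, tmp_ptr, N, BLOCK: tl.constexpr):
    # Hypothetical illustration: each lane stores its own element, then reads an
    # element stored by a *different* lane, so all stores must be visible to the
    # whole block before the loads run.
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(tmp_ptr + offs, x * 2.0, mask=mask)
    tl.debug_barrier()  # roughly __syncthreads(); without it the reversed read below can race
    rev = N - 1 - offs
    y = tl.load(tmp_ptr + rev, mask=(rev >= 0) & (rev < N), other=0.0)
    tl.store(x_ptr + offs, y, mask=mask)
```

The barrier presumably trades a little performance for correctness on head dimensions whose tiles don't divide evenly into the block shape.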
Tri Dao
bedcbd6a71
Disable some autotune configs that give wrong results in Triton bwd
2022-10-31 01:05:51 -07:00
Tri Dao
e78d509c64
[WIP] Support all head dimensions up to 128 in the Triton bwd
...
WIP because there seem to be some race conditions for head dimensions other
than 16, 32, 64, 128.
2022-10-31 00:46:22 -07:00
Tri Dao
008951f1d9
Support all head dimensions up to 128 in the Triton fwd
2022-10-30 22:10:48 -07:00
Tri Dao
b910bf14c1
Support arbitrary seqlens (both q & k) in Triton bwd
2022-10-30 21:50:53 -07:00
Tri Dao
dc55469355
Support arbitrary seqlen_k in Triton bwd
2022-10-30 21:26:26 -07:00
Tri Dao
d11341fd1a
Fix Triton fwd to support seqlen not multiples of 128
2022-10-30 19:05:47 -07:00
Tri Dao
b0c0db81f6
Implement FlashAttention in Triton
2022-10-30 18:09:11 -07:00
Tri Dao
c422fee377
Get rid of o_rows_are_valid since we don't have headdim=16 anymore
2022-10-24 17:29:36 -07:00
Tri Dao
46fd2a20b2
Support all head dims that are multiples of 8, up to 128
2022-10-24 16:04:21 -07:00
Tri Dao
97e13de2b4
Cast q.get_device() to char to avoid compiler warning (narrowing)
2022-10-24 15:59:49 -07:00
Tri Dao
ed553e9238
Add Megatron attention implementation for benchmarking
2022-10-23 23:04:16 -07:00
Tri Dao
50ca23488d
Add Triton implementation for benchmarking
2022-10-23 17:25:56 -07:00
Tri Dao
9e92a1f2d2
Attempt to use atomicCAS to replace atomicAdd(bfloat16)
2022-10-23 16:22:43 -07:00
Tri Dao
6731855b1f
Merge pull request #61 from robotcator/workflow
...
build wheel and upload to release
2022-10-23 12:52:51 -07:00
Tri Dao
fb88e5e4b3
Move benchmark utils, support AMP
2022-10-23 12:50:00 -07:00
Tri Dao
a5a8806d1a
Split bwd on the seqlen_q dimension
2022-10-23 11:35:15 -07:00
Tri Dao
871db47941
Don't need to run configure for the forward pass
2022-10-21 18:22:27 -07:00
Tri Dao
7fc39832e2
Use block_size=128 for headdim=128 on SM80
...
Previously we were using block_size=256.
2022-10-21 13:19:54 -07:00
Tri Dao
a44f48df5a
Split fwd on the seqlen_q dimension
2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6
Rework dropout to decouple forward and backward
...
The forward and backward passes no longer have to use the same block size, number of threads, etc. (see the sketch below).
2022-10-21 12:04:27 -07:00
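One way to read the rework above: if the dropout keep/drop decision is a pure function of a seed and an element's global offset (a counter-based RNG such as Philox), the backward pass can regenerate exactly the same mask as the forward pass regardless of block size or thread count. The sketch below is a toy stand-in in plain Python, not the repository's implementation; the hash merely plays the role of the GPU RNG.

```python
def keep(seed: int, offset: int, p_drop: float) -> bool:
    # Stateless decision: depends only on (seed, offset), never on how elements
    # are grouped into blocks/threads. The mix below is a toy stand-in for Philox.
    h = (seed * 0x9E3779B97F4A7C15 + offset * 0xBF58476D1CE4E5B9) & 0xFFFFFFFFFFFFFFFF
    h ^= h >> 29
    h = (h * 0x94D049BB133111EB) & 0xFFFFFFFFFFFFFFFF
    h ^= h >> 32
    u = (h & 0xFFFFFF) / float(1 << 24)  # uniform in [0, 1)
    return u >= p_drop


# The "forward" walks elements in one pass; the "backward" walks them in 64-element
# blocks. Both reconstruct the identical mask, which is why the two passes no longer
# need matching launch configurations.
seed, p, n = 42, 0.1, 256
mask_fwd = [keep(seed, i, p) for i in range(n)]
mask_bwd = []
for start in range(0, n, 64):
    mask_bwd.extend(keep(seed, i, p) for i in range(start, start + 64))
assert mask_fwd == mask_bwd
```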
Tri Dao
1d0b41be3b
Merge pull request #60 from 201419/patch-1
...
fix typo in function mha_fwd
2022-10-17 09:38:48 -07:00
robotcator
35d589fa81
Merge branch 'main' of github.com:robotcator/flash-attention into workflow
2022-10-17 17:41:37 +08:00
robotcator
10d0745966
using tag trigger rather than push trigger
2022-10-17 17:39:08 +08:00
YangShu
ff07250e8f
fix typo in function mha_fwd
...
As the title says.
2022-10-17 16:13:47 +08:00
Tri Dao
52fb4b729b
Fix #54: set device for multi-GPU case
2022-10-16 12:51:26 -07:00
Tri Dao
1b9facacc3
Fix QKV interface to allocate output in Python
2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848
Implement attention kernel that splits the batch into two
2022-10-13 20:49:02 -07:00
Tri Dao
f515c77f25
Merge pull request #53 from robotcator/workflow
...
build wheel workflow
2022-10-09 22:26:22 -07:00
Tri Dao
8dd52b0788
Merge pull request #55 from ajfadam/main
...
remove numpy dependency
2022-10-06 10:29:38 -07:00
Antoine Adam
4e38df059e
remove numpy dependency
...
According to the `setup.py` file, the only dependencies are torch and einops, but `bert_padding.py` requires `numpy` solely to multiply the elements of a `torch.Size` object. This change allows FlashAttention to be used without numpy (see the sketch below).
2022-10-06 19:17:15 +02:00
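The description above implies numpy was used only to take a product over a `torch.Size`, which the standard library already handles. A minimal sketch of the idea, under the assumption that the call looked roughly like `np.prod(x.shape[...])` (hypothetical shapes, not the actual diff):

```python
import math

import torch

x = torch.randn(4, 512, 768)  # hypothetical (batch, seqlen, hidden) tensor

# torch.Size is a tuple subclass, so math.prod works on it directly and can
# replace np.prod for this purpose.
batch_times_seqlen = math.prod(x.shape[:2])  # 4 * 512 = 2048
assert math.prod(x.shape) == x.numel()
```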
Tri Dao
88dc2040a0
Merge pull request #52 from bob80333/main
...
Make flash attention compile on Windows.
2022-10-04 21:25:37 -07:00
Eric Engelhart
2211db5fab
Fixed switch statement, thanks @yocabon
2022-10-04 21:31:39 -04:00
Eric Engelhart
9b1b011bf6
Add C++17 flag to the compiler since C++17 features are used; fixes Windows build
2022-10-04 21:31:39 -04:00
Eric Engelhart
9d7fd5b6e7
Replace BOOL_SWITCH with FP16_SWITCH to work around MSVC bug with constexpr variables and templates
2022-10-04 21:31:39 -04:00
Tri Dao
0c01568daf
Only run backward test for d=128 on A100
2022-10-04 18:06:08 -07:00
robotcator
2c853fe821
add publish
2022-09-26 10:59:48 +08:00
robotcator
f7e7e912c1
add publish
2022-09-26 10:57:37 +08:00
Tri Dao
8166063a55
Use block_size=128 for d=128 on SM86 to avoid exceeding smem limit
2022-09-12 14:21:29 -07:00
Tri Dao
13403e8115
Relax assert to allow both bf16 and fp16
2022-09-11 12:09:43 -07:00
Tri Dao
64f42cd057
Change license from Apache 2.0 to BSD
2022-09-09 12:07:35 -07:00
Tri Dao
04fb198523
Merge pull request #43 from eric-tc-wong/patch-1
...
Update flash_attention.py
2022-09-06 14:37:31 -07:00
eric-tc-wong
b410d14f28
Update flash_attention.py
...
Recasting query and key after rotary_emb()
2022-09-06 17:29:49 -04:00
Tri Dao
19d1261025
Add back need_weights in FlashMHA
2022-08-09 10:14:10 -07:00
Tri Dao
6cc7342575
Support index_first_axis with more than 2 dimensions
2022-08-05 09:48:16 -07:00