Commit Graph

176 Commits

Author SHA1 Message Date
robotcator
35d589fa81 Merge branch 'main' of github.com:robotcator/flash-attention into workflow 2022-10-17 17:41:37 +08:00
robotcator
10d0745966 using tag trigger rather than push trigger 2022-10-17 17:39:08 +08:00
YangShu
ff07250e8f fix typo in function mha_fwd 2022-10-17 16:13:47 +08:00
    as title.
Tri Dao
52fb4b729b Fix #54: set device for multi-GPU case 2022-10-16 12:51:26 -07:00
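    The log does not show the patch itself; as a hypothetical illustration of the usual pattern behind such a fix (not necessarily the actual change), the current CUDA device is switched to the one holding the inputs before any allocation or kernel launch:

    ```python
    import torch

    def run_on_input_device(qkv: torch.Tensor) -> torch.Tensor:
        # Hypothetical sketch: on a multi-GPU machine the "current" CUDA device
        # may not be the one holding qkv, so allocations and kernel launches are
        # wrapped in a device guard keyed off the input tensor.
        with torch.cuda.device(qkv.device):
            out = torch.empty_like(qkv)  # stand-in for the real attention kernel
            out.copy_(qkv)
        return out
    ```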
Tri Dao
1b9facacc3 Fix QKV interface to allocate output in Python 2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848 Implement attention kernel that splits the batch into two 2022-10-13 20:49:02 -07:00
Tri Dao
f515c77f25 Merge pull request #53 from robotcator/workflow 2022-10-09 22:26:22 -07:00
    build wheel workflow
Tri Dao
8dd52b0788 Merge pull request #55 from ajfadam/main 2022-10-06 10:29:38 -07:00
    remove numpy dependency
Antoine Adam
4e38df059e remove numpy dependency 2022-10-06 19:17:15 +02:00
    According to the `setup.py` file, the only dependencies are torch and einops, but `bert_padding.py` requires `numpy` solely to multiply the elements of a `torch.Size` object. This change allows FlashAttention to be used without numpy.
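    For context on the change above: `torch.Size` is a plain tuple subclass, so the product of its elements can be computed with the standard library instead of `numpy.prod`. A minimal sketch of the idea (the exact call site in `bert_padding.py` is not reproduced here):

    ```python
    import math
    import torch

    x = torch.randn(4, 16, 8)

    # numpy-free replacement for numpy.prod(x.shape[1:]):
    # torch.Size is a tuple, so math.prod (Python 3.8+) works on it directly.
    flat_dim = math.prod(x.shape[1:])
    assert x.reshape(x.shape[0], flat_dim).shape == (4, 16 * 8)
    ```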
Tri Dao
88dc2040a0 Merge pull request #52 from bob80333/main 2022-10-04 21:25:37 -07:00
    Make flash attention compile on Windows.
Eric Engelhart
2211db5fab Fixed switch statement, thanks @yocabon 2022-10-04 21:31:39 -04:00
Eric Engelhart
9b1b011bf6 Add C++17 arg to compiler, since C++17 features are used, fixes windows build 2022-10-04 21:31:39 -04:00
Eric Engelhart
9d7fd5b6e7 Replace BOOL_SWITCH with FP16_SWITCH to work around MSVC bug with constexpr variables and templates 2022-10-04 21:31:39 -04:00
Tri Dao
0c01568daf Only run backward test for d=128 on A100 2022-10-04 18:06:08 -07:00
robotcator
2c853fe821 add publish 2022-09-26 10:59:48 +08:00
robotcator
f7e7e912c1 add publish 2022-09-26 10:57:37 +08:00
Tri Dao
8166063a55 Use block_size=128 for d=128 on SM86 to avoid exceeding smem limit 2022-09-12 14:21:29 -07:00
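    For context: compute capability 8.6 parts (e.g. RTX 3090) allow at most about 99 KB of shared memory per thread block, versus about 163 KB on compute capability 8.0 (A100). The back-of-the-envelope arithmetic below assumes the kernel keeps one K tile and one V tile of shape (block_size, head_dim) in fp16 shared memory; the real kernel's layout is more involved, so this is only illustrative:

    ```python
    # Illustrative arithmetic only; the actual kernel's shared-memory layout
    # holds more than just K and V tiles.
    BYTES_FP16 = 2
    HEAD_DIM = 128
    LIMIT_KB = {"sm80 (A100)": 163, "sm86 (RTX 30xx)": 99}  # max smem per block

    def kv_tile_kb(block_size: int) -> float:
        # One K tile plus one V tile of shape (block_size, HEAD_DIM) in fp16.
        return 2 * block_size * HEAD_DIM * BYTES_FP16 / 1024

    for arch, limit in LIMIT_KB.items():
        for block_size in (256, 128):
            kb = kv_tile_kb(block_size)
            print(f"{arch}: block_size={block_size} -> {kb:.0f} KB "
                  f"(limit {limit} KB, fits={kb <= limit})")
    ```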
Tri Dao
13403e8115 Relax assert to allow both bf16 and fp16 2022-09-11 12:09:43 -07:00
Tri Dao
64f42cd057 Change license from Apache 2.0 to BSD 2022-09-09 12:07:35 -07:00
Tri Dao
04fb198523 Merge pull request #43 from eric-tc-wong/patch-1 2022-09-06 14:37:31 -07:00
    Update flash_attention.py
eric-tc-wong
b410d14f28 Update flash_attention.py 2022-09-06 17:29:49 -04:00
    Recasting query and key after rotary_emb()
Tri Dao
19d1261025 Add back need_weights in FlashMHA 2022-08-09 10:14:10 -07:00
Tri Dao
6cc7342575 Support index_first_axis with more than 2 dimensions 2022-08-05 09:48:16 -07:00
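    `index_first_axis` (from `bert_padding.py`) gathers selected rows along the first axis of a tensor; supporting more than two dimensions amounts to flattening the trailing dimensions before the gather and restoring them afterwards. A minimal sketch of the forward semantics only, under that assumption (the repository's helper is presumably wrapped in a custom autograd function, omitted here):

    ```python
    import torch

    def index_first_axis_sketch(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        # Gather rows along the first axis of a tensor with arbitrary trailing dims.
        trailing = x.shape[1:]
        flat = x.reshape(x.shape[0], -1)                       # (N, prod(trailing))
        idx = indices.unsqueeze(-1).expand(-1, flat.shape[1])  # one row index per column
        return torch.gather(flat, 0, idx).reshape(-1, *trailing)

    x = torch.randn(10, 3, 4, 8)
    keep = torch.tensor([0, 2, 5])
    assert torch.equal(index_first_axis_sketch(x, keep), x[keep])
    ```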
Tri Dao
713ea302d7 Allow headdim 128 in FlashMHA interface 2022-08-05 09:47:22 -07:00
Tri Dao
2ed471ecc4 Add tests for numerical error 2022-07-22 17:54:09 -04:00
Tri Dao
42f54d8840 Edit mention of Triton implementation 2022-07-11 17:02:29 -07:00
    Phil Tillet suggests calling it "experimental".
Tri Dao
4577151ff8 Link to Triton implementation 2022-07-11 16:01:43 -07:00
Tri Dao
bc2c210254 Don't nest BOOL_SWITCH to work around gcc 7 bug 2022-07-11 10:28:46 -07:00
Tri Dao
d1fc80a3bb Link to IEEE Spectrum article on MLPerf 2022-07-10 12:11:46 -07:00
Tri Dao
1bbebccc0a Edit README to mention bf16 support 2022-07-09 23:34:29 -07:00
Tri Dao
de19de7ab1 Implement for bf16 2022-07-09 23:31:56 -07:00
Tri Dao
6a77a6da10 Refactor gemm_cl to template on either __half or __nv_bfloat16 2022-07-09 23:18:26 -07:00
Tri Dao
e518a4b327 Refactor to template on __half, implement bf16 util functions 2022-07-09 23:18:26 -07:00
Tri Dao
2dc1b205f6 Fix Illegal Memory Access bug in fwd when d=16 2022-07-09 23:17:14 -07:00
Tri Dao
5b838a8bef Apply dropout scaling to dQ and dK instead of to V (in bwd) 2022-07-03 17:53:37 -07:00
    Theoretically this might have lower numerical error since the scaling is in fp32 instead of fp16 (not sure, I haven't thought too carefully about it). However, in practice, the numerical errors seem about the same.
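    Why the scaling can be moved at all: writing $M$ for the dropout mask and $p$ for the dropout probability, the backward chain below is linear in $dP$, so the scalar $\frac{1}{1-p}$ can be folded into $V$ before the $dO\,V^{\top}$ product, into $dP$, or into the final $dQ$ and $dK$; all three are identical in exact arithmetic, and the commit only changes where the fp16/fp32 rounding happens. These are the standard softmax/dropout backward formulas, not lifted from the kernel itself:

    ```latex
    % Backward-pass chain with dropout (softmax scale 1/\sqrt{d} omitted):
    dP = \frac{1}{1-p}\, M \odot \big(dO\, V^{\top}\big), \qquad
    dS_{ij} = P_{ij}\Big(dP_{ij} - \sum_{k} P_{ik}\, dP_{ik}\Big), \qquad
    dQ = dS\, K, \qquad dK = dS^{\top} Q .
    ```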
Tri Dao
a5559a0e75 Do P * dP (pointwise) in the bwd in fp32 instead of fp16 2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af Implement cross attention 2022-07-03 17:48:12 -07:00
Tri Dao
01947bc93b Merge pull request #18 from gahdritz/main 2022-07-02 23:33:14 -07:00
    Slightly improve installation process
Gustaf
af4a9ce024 Add missing __init__.py 2022-07-03 02:04:55 -04:00
Gustaf
440e9c49f2 Add einops installation to setup.py 2022-07-03 02:04:24 -04:00
Tri Dao
f66603cb6f Support batch size > 64K by swapping grid.x and grid.y 2022-06-29 23:16:24 -07:00
Tri Dao
450b64fe44 Add README section on issues 2022-06-27 13:50:16 -07:00
Tri Dao
c0daa62eaa Add type check (fp16) in the forward pass 2022-06-26 11:41:30 -07:00
Tri Dao
ea38d3d261 Fix race condition in backward pass (smem_dq) 2022-06-25 18:02:30 -07:00
Tri Dao
eeca63a72a Bug fix: wrong smem_o write pointer for d=16 2022-06-25 15:18:33 -07:00
Dan Fu
765741c1ee More explanation 2022-06-14 11:55:14 -07:00
Dan Fu
2d5b2483b8 Speedup graph for A100, d128 2022-06-14 11:54:16 -07:00
Tri Dao
5d07483bbc Refactor Gmem code to store q, k, v pointers separately 2022-06-12 16:37:32 -07:00
Tri Dao
d3e6440958 Implement bwd for head dim 128 2022-06-11 17:52:36 -07:00
Tri Dao
0d854692c6 Implement fwd for head dim 128 2022-06-11 17:52:36 -07:00