Commit Graph

76 Commits

Author SHA1 Message Date
Tri Dao
1aa6d7d9b6 Rework dropout to decouple forward and backward
The forward and backward passes no longer have to use the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
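A conceptual Python stand-in for the idea behind this rework (the real code is a CUDA kernel; the helper and constants below are purely illustrative): if the keep/drop decision for each element is a pure function of a seed and the element's index, the forward and backward passes can tile the work with different block sizes and still reproduce the same dropout mask.

```python
import torch

def keep_flag(seed: int, idx: int, p_drop: float = 0.1) -> bool:
    # The decision for element idx depends only on (seed, idx), mimicking a
    # counter-based RNG, so it is independent of how the work is tiled.
    g = torch.Generator().manual_seed(seed * 1_000_003 + idx)
    return torch.rand((), generator=g).item() > p_drop

n, seed = 16, 42
# "Forward" visits elements in blocks of 4, "backward" in blocks of 8:
fwd = [keep_flag(seed, i) for block in range(0, n, 4) for i in range(block, block + 4)]
bwd = [keep_flag(seed, i) for block in range(0, n, 8) for i in range(block, block + 8)]
assert fwd == bwd  # same dropout mask despite different block sizes
```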
Tri Dao
1d0b41be3b
Merge pull request #60 from 201419/patch-1
fix typo in function mha_fwd
2022-10-17 09:38:48 -07:00
YangShu
ff07250e8f
fix typo in function mha_fwd
As the title says.
2022-10-17 16:13:47 +08:00
Tri Dao
52fb4b729b Fix #54: set device for multi-GPU case 2022-10-16 12:51:26 -07:00
Tri Dao
1b9facacc3 Fix QKV interface to allocate output in Python 2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848 Implement attention kernel that splits the batch into two 2022-10-13 20:49:02 -07:00
Tri Dao
f515c77f25
Merge pull request #53 from robotcator/workflow
build wheel workflow
2022-10-09 22:26:22 -07:00
Tri Dao
8dd52b0788
Merge pull request #55 from ajfadam/main
remove numpy dependency
2022-10-06 10:29:38 -07:00
Antoine Adam
4e38df059e
remove numpy dependency
According to the `setup.py` file, the only dependencies are torch and einops. But the `bert_padding.py` file requires `numpy` only to multiply the elements of a `torch.Size` object. This change aims to allow using FlashAttention without numpy.
2022-10-06 19:17:15 +02:00
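A minimal sketch of the torch-only alternative (illustrative; not necessarily the exact replacement made in this commit): a `torch.Size` can report the product of its own elements, so `numpy` isn't needed for this.

```python
import math
import torch

x = torch.empty(2, 3, 4)
n = x.shape.numel()            # product of the dimensions, no np.prod needed
assert n == math.prod(x.shape) == 24
```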
Tri Dao
88dc2040a0
Merge pull request #52 from bob80333/main
Make flash attention compile on Windows.
2022-10-04 21:25:37 -07:00
Eric Engelhart
2211db5fab Fixed switch statement, thanks @yocabon 2022-10-04 21:31:39 -04:00
Eric Engelhart
9b1b011bf6 Add C++17 arg to compiler, since C++17 features are used, fixes windows build 2022-10-04 21:31:39 -04:00
Eric Engelhart
9d7fd5b6e7 Replace BOOL_SWITCH with FP16_SWITCH to work around MSVC bug with constexpr variables and templates 2022-10-04 21:31:39 -04:00
Tri Dao
0c01568daf Only run backward test for d=128 on A100 2022-10-04 18:06:08 -07:00
robotcator
2c853fe821 add publish 2022-09-26 10:59:48 +08:00
robotcator
f7e7e912c1 add publish 2022-09-26 10:57:37 +08:00
Tri Dao
8166063a55 Use block_size=128 for d=128 on SM86 to avoid exceeding smem limit 2022-09-12 14:21:29 -07:00
Tri Dao
13403e8115 Relax assert to allow both bf16 and fp16 2022-09-11 12:09:43 -07:00
Tri Dao
64f42cd057 Change license from Apache 2.0 to BSD 2022-09-09 12:07:35 -07:00
Tri Dao
04fb198523
Merge pull request #43 from eric-tc-wong/patch-1
Update flash_attention.py
2022-09-06 14:37:31 -07:00
eric-tc-wong
b410d14f28
Update flash_attention.py
Recasting query and key after rotary_emb()
2022-09-06 17:29:49 -04:00
Tri Dao
19d1261025 Add back need_weights in FlashMHA 2022-08-09 10:14:10 -07:00
Tri Dao
6cc7342575 Support index_first_axis with more than 2 dimensions 2022-08-05 09:48:16 -07:00
Tri Dao
713ea302d7 Allow headdim 128 in FlashMHA interface 2022-08-05 09:47:22 -07:00
Tri Dao
2ed471ecc4 Add tests for numerical error 2022-07-22 17:54:09 -04:00
Tri Dao
42f54d8840 Edit mention of Triton implementation
Phil Tillet suggests calling it "experimental".
2022-07-11 17:02:29 -07:00
Tri Dao
4577151ff8 Link to Triton implementation 2022-07-11 16:01:43 -07:00
Tri Dao
bc2c210254 Don't nest BOOL_SWITCH to work around gcc 7 bug 2022-07-11 10:28:46 -07:00
Tri Dao
d1fc80a3bb Link to IEEE Spectrum article on MLPerf 2022-07-10 12:11:46 -07:00
Tri Dao
1bbebccc0a Edit README to mention bf16 support 2022-07-09 23:34:29 -07:00
Tri Dao
de19de7ab1 Implement for bf16 2022-07-09 23:31:56 -07:00
Tri Dao
6a77a6da10 Refactor gemm_cl to template on either __half or __nv_bfloat16 2022-07-09 23:18:26 -07:00
Tri Dao
e518a4b327 Refactor to template on __half, implement bf16 util functions 2022-07-09 23:18:26 -07:00
Tri Dao
2dc1b205f6 Fix Illegal Memory Access bug in fwd when d=16 2022-07-09 23:17:14 -07:00
Tri Dao
5b838a8bef Apply dropout scaling to dQ and dK instead of to V (in bwd)
Theoretically this might have lower numerical error since the scaling is in
fp32 instead of fp16 (not sure, I haven't thought too carefully about it).
However, in practice, the numerical errors seem about the same.
2022-07-03 17:53:37 -07:00
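An illustrative Python comparison of the two places the dropout rescaling can happen (not the kernel code; the tensor sizes and dropout probability are made up): scaling an fp16 operand rounds the scale factor into fp16, while scaling inside an fp32 accumulation keeps it in fp32. As the commit note says, the difference is often negligible in practice.

```python
import torch

torch.manual_seed(0)
scale = 1.0 / (1.0 - 0.1)          # dropout rescale factor for p=0.1
a = torch.randn(4096).half()
b = torch.randn(4096).half()

ref = (a.double() * b.double() * scale).sum()               # high-precision reference
scaled_in_fp16 = ((a * scale).float() * b.float()).sum()    # scale an fp16 operand (like scaling V)
scaled_in_fp32 = ((a.float() * b.float()) * scale).sum()    # scale in fp32 (like scaling dQ/dK)

print(abs(scaled_in_fp16 - ref).item(), abs(scaled_in_fp32 - ref).item())
```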
Tri Dao
a5559a0e75 Do P * dP (pointwise) in the bwd in fp32 instead of fp16 2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af Implement cross attention 2022-07-03 17:48:12 -07:00
Tri Dao
01947bc93b
Merge pull request #18 from gahdritz/main
Slightly improve installation process
2022-07-02 23:33:14 -07:00
Gustaf
af4a9ce024 Add missing __init__.py 2022-07-03 02:04:55 -04:00
Gustaf
440e9c49f2 Add einops installation to setup.py 2022-07-03 02:04:24 -04:00
Tri Dao
f66603cb6f Support batch size > 64K by swapping grid.x and grid.y 2022-06-29 23:16:24 -07:00
Tri Dao
450b64fe44 Add README section on issues 2022-06-27 13:50:16 -07:00
Tri Dao
c0daa62eaa Add type check (fp16) in the forward pass 2022-06-26 11:41:30 -07:00
Tri Dao
ea38d3d261 Fix race condition in backward pass (smem_dq) 2022-06-25 18:02:30 -07:00
Tri Dao
eeca63a72a Bug fix: wrong smem_o write pointer for d=16 2022-06-25 15:18:33 -07:00
Dan Fu
765741c1ee More explanation 2022-06-14 11:55:14 -07:00
Dan Fu
2d5b2483b8 Speedup graph for A100, d128 2022-06-14 11:54:16 -07:00
Tri Dao
5d07483bbc Refactor Gmem code to store q, k, v pointers separately 2022-06-12 16:37:32 -07:00
Tri Dao
d3e6440958 Implement bwd for head dim 128 2022-06-11 17:52:36 -07:00
Tri Dao
0d854692c6 Implement fwd for head dim 128 2022-06-11 17:52:36 -07:00