Tri Dao
008951f1d9
Support all head dimensions up to 128 in the Triton fwd
2022-10-30 22:10:48 -07:00
Tri Dao
b910bf14c1
Support arbitrary seqlens (both q & k) in Triton bwd
2022-10-30 21:50:53 -07:00
Tri Dao
dc55469355
Support arbitrary seqlen_k in Triton bwd
2022-10-30 21:26:26 -07:00
Tri Dao
d11341fd1a
Fix Triton fwd to support seqlen not multiples of 128
2022-10-30 19:05:47 -07:00
Tri Dao
b0c0db81f6
Implement FlashAttention in Triton
2022-10-30 18:09:11 -07:00
Tri Dao
c422fee377
Get rid of o_rows_are_valid since we don't have headdim=16 anymore
2022-10-24 17:29:36 -07:00
Tri Dao
46fd2a20b2
Support all head dims that are multiples of 8, up to 128
2022-10-24 16:04:21 -07:00
Tri Dao
97e13de2b4
Cast q.get_device() to char to avoid compiler warning (narrowing)
2022-10-24 15:59:49 -07:00
Tri Dao
ed553e9238
Add Megatron attention implementation for benchmarking
2022-10-23 23:04:16 -07:00
Tri Dao
50ca23488d
Add Triton implementation for benchmarking
2022-10-23 17:25:56 -07:00
Tri Dao
9e92a1f2d2
Attempt to use atomicCAS to replace atomicAdd(bfloat16)
2022-10-23 16:22:43 -07:00
Tri Dao
6731855b1f
Merge pull request #61 from robotcator/workflow
build wheel and upload to release
2022-10-23 12:52:51 -07:00
Tri Dao
fb88e5e4b3
Move benchmark utils, support AMP
2022-10-23 12:50:00 -07:00
Tri Dao
a5a8806d1a
Split bwd on the seqlen_q dimension
2022-10-23 11:35:15 -07:00
Tri Dao
871db47941
Don't need to run configure for the forward pass
2022-10-21 18:22:27 -07:00
Tri Dao
7fc39832e2
Use block_size=128 for headdim=128 on SM80
Previously we were using block_size=256.
2022-10-21 13:19:54 -07:00
Tri Dao
a44f48df5a
Split fwd on the seqlen_q dimension
2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6
Rework dropout to decouple forward and backward
They don't have to have the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
Tri Dao
1d0b41be3b
Merge pull request #60 from 201419/patch-1
fix typo in function mha_fwd
2022-10-17 09:38:48 -07:00
robotcator
35d589fa81
Merge branch 'main' of github.com:robotcator/flash-attention into workflow
2022-10-17 17:41:37 +08:00
robotcator
10d0745966
using tag trigger rather than push trigger
2022-10-17 17:39:08 +08:00
YangShu
ff07250e8f
fix typo in function mha_fwd
As stated in the title.
2022-10-17 16:13:47 +08:00
Tri Dao
52fb4b729b
Fix #54: set device for multi-GPU case
2022-10-16 12:51:26 -07:00
Tri Dao
1b9facacc3
Fix QKV interface to allocate output in Python
2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848
Implement attention kernel that splits the batch into two
2022-10-13 20:49:02 -07:00
Tri Dao
f515c77f25
Merge pull request #53 from robotcator/workflow
build wheel workflow
2022-10-09 22:26:22 -07:00
Tri Dao
8dd52b0788
Merge pull request #55 from ajfadam/main
remove numpy dependency
2022-10-06 10:29:38 -07:00
Antoine Adam
4e38df059e
remove numpy dependency
According to the `setup.py` file, the only dependencies are torch and einops. But the `bert_padding.py` file requires `numpy` only to multiply the elements of a `torch.Size` object. This change allows FlashAttention to be used without numpy (see the sketch after this entry).
2022-10-06 19:17:15 +02:00
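A minimal sketch of the kind of replacement this commit describes, assuming the numpy use was something like `np.prod` over a tensor shape; the helper name `shape_numel` is hypothetical, not from the repository:

```python
import math

import torch


def shape_numel(shape: torch.Size) -> int:
    # Hypothetical helper: torch.Size is a tuple subclass, so the standard
    # library's math.prod multiplies its elements directly, giving the same
    # value as np.prod(shape) without the numpy dependency.
    return math.prod(shape)


x = torch.zeros(2, 3, 4)
assert shape_numel(x.shape) == x.numel() == 24
```

`math.prod` is in the Python standard library (3.8+), so the torch-and-einops-only dependency list in `setup.py` stays accurate.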
Tri Dao
88dc2040a0
Merge pull request #52 from bob80333/main
Make flash attention compile on Windows.
2022-10-04 21:25:37 -07:00
Eric Engelhart
2211db5fab
Fixed switch statement, thanks @yocabon
2022-10-04 21:31:39 -04:00
Eric Engelhart
9b1b011bf6
Add C++17 arg to compiler, since C++17 features are used; fixes Windows build
2022-10-04 21:31:39 -04:00
Eric Engelhart
9d7fd5b6e7
Replace BOOL_SWITCH with FP16_SWITCH to work around MSVC bug with constexpr variables and templates
2022-10-04 21:31:39 -04:00
Tri Dao
0c01568daf
Only run backward test for d=128 on A100
2022-10-04 18:06:08 -07:00
robotcator
2c853fe821
add publish
2022-09-26 10:59:48 +08:00
robotcator
f7e7e912c1
add publish
2022-09-26 10:57:37 +08:00
Tri Dao
8166063a55
Use block_size=128 for d=128 on SM86 to avoid exceeding smem limit
2022-09-12 14:21:29 -07:00
Tri Dao
13403e8115
Relax assert to allow both bf16 and fp16
2022-09-11 12:09:43 -07:00
Tri Dao
64f42cd057
Change license from Apache 2.0 to BSD
2022-09-09 12:07:35 -07:00
Tri Dao
04fb198523
Merge pull request #43 from eric-tc-wong/patch-1
Update flash_attention.py
2022-09-06 14:37:31 -07:00
eric-tc-wong
b410d14f28
Update flash_attention.py
Recasting query and key after rotary_emb()
2022-09-06 17:29:49 -04:00
Tri Dao
19d1261025
Add back need_weights in FlashMHA
2022-08-09 10:14:10 -07:00
Tri Dao
6cc7342575
Support index_first_axis with more than 2 dimensions
2022-08-05 09:48:16 -07:00
Tri Dao
713ea302d7
Allow headdim 128 in FlashMHA interface
2022-08-05 09:47:22 -07:00
Tri Dao
2ed471ecc4
Add tests for numerical error
2022-07-22 17:54:09 -04:00
Tri Dao
42f54d8840
Edit mention of Triton implementation
Phil Tillet suggests calling it "experimental".
2022-07-11 17:02:29 -07:00
Tri Dao
4577151ff8
Link to Triton implementation
2022-07-11 16:01:43 -07:00
Tri Dao
bc2c210254
Don't nest BOOL_SWITCH to work around gcc 7 bug
2022-07-11 10:28:46 -07:00
Tri Dao
d1fc80a3bb
Link to IEEE Spectrum article on MLPerf
2022-07-10 12:11:46 -07:00
Tri Dao
1bbebccc0a
Edit README to mention bf16 support
2022-07-09 23:34:29 -07:00
Tri Dao
de19de7ab1
Implement for bf16
2022-07-09 23:31:56 -07:00