Antoni Viros
83e41b3ca4
Add custom ops for compatibility with PT Compile (#1139)
...
* Add custom ops for compatibility with PT Compile
* Add support for varlen functions too
* Add version checks for pytorch API
* Fix PT compile interfaces so it works e2e
* Make sure PT < 2.4 runs fine
* Fix python mistake
* Fix all the autograd magic issues
* typo on head_dim
* Fix deterministic test failures, remove unneeded detaches()
* remove test requires_grad
* Resolve all the pytorch versioning issues
* C++ and python refactor to improve padding management for torch.compile()
* Add improvements suggested by @anijain2305
2024-09-17 19:49:26 -07:00
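The entry above registers the kernels as PyTorch custom ops so torch.compile can trace through them. A minimal sketch of what that enables, assuming a CUDA device, PyTorch >= 2.4, and a flash-attn build from this revision; shapes are illustrative:

```python
# Sketch: flash_attn_func inside a torch.compile'd function. The custom-op
# wrappers from this commit let compile trace the call without graph breaks.
import torch
from flash_attn import flash_attn_func

@torch.compile(fullgraph=True)
def attn(q, k, v):
    # q, k, v: (batch, seqlen, nheads, headdim) in fp16/bf16 on CUDA
    return flash_attn_func(q, k, v, causal=True)

q, k, v = [torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
           for _ in range(3)]
out = attn(q, k, v)
```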
youkaichao
ef3e358a25
remove lambda (#1056)
2024-07-21 23:24:38 -07:00
Tri Dao
898dd4bbf2
Pass seqused_k to _flash_attn_varlen_forward
2024-07-13 00:08:27 -07:00
Tri Dao
40e534a7f6
Implement cache_leftpad
2024-07-11 08:17:15 -07:00
Tri Dao
81e01efd4b
More typo fixes
2024-07-10 10:19:17 -07:00
Tri Dao
72e27c6320
Fix typo with softcapping
2024-07-10 00:33:52 -07:00
Phil Wang
f4628b43ec
missing commas and backwards return arguments (#1032)
...
* missing commas
* another fix
2024-07-09 10:56:29 -07:00
Nicolas Patry
8f873cc6ac
Implement softcapping. (#1025)
...
* Softcap v2 (fwd only).
* Some missing interface + remove overrides in tests.
2024-07-08 11:24:48 -07:00
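A sketch of the feature added here, assuming a build where the Python interface exposes the `softcap` argument (at this revision, forward pass only):

```python
# Sketch: attention-logit soft-capping. With softcap > 0 the scores are mapped
# through softcap * tanh(scores / softcap) before the softmax; 0.0 disables it.
import torch
from flash_attn import flash_attn_func

q = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)

out = flash_attn_func(q, k, v, causal=True, softcap=30.0)
```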
Jianwei Dong
4e8d60069f
Add the return_softmax_lse parameter to the flash_attn_with_kvcache function to allow returning the logsumexp of the attention scores. (#989)
2024-07-08 08:29:40 -07:00
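A sketch of the new flag, assuming the flash_attn_with_kvcache signature at this revision; shapes are illustrative:

```python
# Sketch: ask the KV-cache path to also return the logsumexp of the attention
# scores, e.g. for merging attention computed over separate KV shards.
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 128
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.randn(batch, 4096, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.randn(batch, 4096, nheads, headdim, device="cuda", dtype=torch.float16)
cache_seqlens = torch.full((batch,), 1000, dtype=torch.int32, device="cuda")

out, lse = flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True,
    return_softmax_lse=True,  # flag added by this commit
)
```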
Grigory Sizov
f816dee63c
Support unpadded LSE layout (#970)
...
* Support unpadded LSE layout.
Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
* Cleanup
* Fix unpadded LSE on split-kv path
* Fix formatting and comments
* Fix inline vs forceinline
---------
Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
2024-06-27 02:38:13 -07:00
Grigory Sizov
2a15840f09
Enable paged attention in varlen forward (#831)
...
* Enable paged attention in varlen forward
* Format + fix padding
2024-03-15 00:48:19 -07:00
Tao He
204c3c6d1b
Fixes an error in a comment (#785)
...
Signed-off-by: Tao He <sighingnow@gmail.com>
2024-01-23 12:38:29 -08:00
Tri Dao
54e80a3829
Implement paged KV cache
...
Co-authored-by: ljss <450993438@qq.com>
2024-01-22 22:47:30 -08:00
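A sketch of the paged layout, assuming the interface at this revision: the caches hold fixed-size pages and `block_table` maps each sequence to its pages (page size and shapes here are illustrative):

```python
# Sketch: paged KV cache. Instead of one contiguous cache per sequence, keys and
# values live in fixed-size pages that block_table indexes per batch element.
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 128
page_size, num_pages, max_pages_per_seq = 256, 64, 16

q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.randn(num_pages, page_size, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.randn(num_pages, page_size, nheads, headdim, device="cuda", dtype=torch.float16)

# block_table[b, i] = index of the page holding tokens [i*page_size, (i+1)*page_size) of sequence b
block_table = torch.randint(0, num_pages, (batch, max_pages_per_seq),
                            dtype=torch.int32, device="cuda")
cache_seqlens = torch.tensor([1000, 700], dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens,
                              block_table=block_table, causal=True)
```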
Tri Dao
a7b66ae25a
Simplify writing softmax to gmem
2024-01-13 00:25:04 -08:00
Tri Dao
732654583c
Implement deterministic backward (thanks to Meituan)
2023-12-23 17:57:36 -08:00
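A sketch of the option this adds, assuming a build with the deterministic backward; the flag trades some backward throughput for reproducibility:

```python
# Sketch: deterministic backward pass. Gradients are bitwise-reproducible
# run-to-run, at the cost of some backward speed.
import torch
from flash_attn import flash_attn_func

q, k, v = [torch.randn(2, 256, 8, 64, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3)]
out = flash_attn_func(q, k, v, causal=True, deterministic=True)
out.sum().backward()
```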
Tri Dao
5ab9b3667b
Clean up alibi, implement non-causal alibi
2023-12-21 22:27:40 -08:00
Tri Dao
bc28eacc60
Format flash_attn_interface.py
2023-12-19 23:13:53 -08:00
Sanghun Cho
e4f726fc44
Support alibi, by Sanghun Cho from Kakao Brain
...
* hard-code alibi in fwd
* use params.h as num_heads
* hard-code alibi in bwd
* add alibi on/off option
* compute alibi_start, ratio outside of kernels
* fix minor merge conflict
* add test_alibi.py
* change apply_alibi() location before masking
* add alibi in splitkv kernel
* fix backward func # of returns
* add out-of-bound check in apply_alibi()
* update test_alibi.py
* update test_alibi.py for kvcache
* simplify alibi parameter interface
* fix performance issue by computing alibi outside of branch
* update test_flash_attn_varlen_func() for left padding
* implement alibi_slopes (b, nh) loading
* optimize apply_alibi() a bit
* update test cases for alibi_slopes loading
* reflect stylistic comments
* disable "seqlenq_ngroups_swapped" when using alibi
---------
Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
2023-12-19 22:56:06 -08:00
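A sketch of the interface added by this PR, assuming `alibi_slopes` of shape (nheads,) or (batch, nheads) in fp32, with the usual geometric ALiBi slope schedule:

```python
# Sketch: ALiBi attention bias via per-head slopes.
import torch
from flash_attn import flash_attn_func

nheads = 8
# Standard ALiBi schedule for a power-of-two head count: 2^-1, 2^-2, ..., 2^-8.
slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / nheads) for i in range(nheads)],
                      device="cuda", dtype=torch.float32)

q = torch.randn(2, 1024, nheads, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, nheads, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, nheads, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True, alibi_slopes=slopes)
```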
Tri Dao
d4a7c8ffbb
[CI] Only compile for CUDA 11.8 & 12.2, MAX_JOBS=2, add torch-nightly
2023-11-27 16:21:28 -08:00
Jeremy Reizenstein
ce3e7280f8
Allow varlen_fwd to take optional seqused_k (#647)
...
Co-authored-by: bottler <bottler@users.noreply.github.com>
2023-11-27 00:41:23 -08:00
Tri Dao
e279bf8ed9
[Gen] Accept cache_batch_idx to index into the KV cache
2023-10-03 16:27:26 -07:00
Tri Dao
083e8f525f
Implement local attention
...
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
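A sketch of the sliding-window interface, assuming `window_size=(left, right)` with -1 meaning unbounded on that side:

```python
# Sketch: local (sliding-window) attention. Each query attends to at most
# `left` keys before it and `right` keys after it; (-1, -1) disables the window.
import torch
from flash_attn import flash_attn_func

q = torch.randn(1, 4096, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 4096, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 4096, 8, 64, device="cuda", dtype=torch.bfloat16)

# Causal attention restricted to the previous 1024 positions.
out = flash_attn_func(q, k, v, causal=True, window_size=(1024, 0))
```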
Tri Dao
ccbb14f38e
Implement rotary embedding in flash_attn_with_kvcache
2023-09-16 01:20:16 -07:00
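A sketch of the fused rotary path, assuming `rotary_cos`/`rotary_sin` of shape (seqlen, rotary_dim // 2); the embedding is applied to q and to the incoming k before the cache update (parameter names per the interface at this revision):

```python
# Sketch: rotary embedding applied inside flash_attn_with_kvcache.
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 8, 128, 4096
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
v_new = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")

# Precomputed rotary tables of shape (max_seqlen, headdim // 2).
pos = torch.arange(max_seqlen, device="cuda", dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, headdim, 2, device="cuda", dtype=torch.float32) / headdim))
freqs = torch.outer(pos, inv_freq)
rotary_cos, rotary_sin = freqs.cos().half(), freqs.sin().half()

out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    rotary_cos=rotary_cos, rotary_sin=rotary_sin, rotary_interleaved=False,
    cache_seqlens=cache_seqlens, causal=True,
)
```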
Tri Dao
ee77b931b9
Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza)
2023-09-10 22:56:33 -07:00
Tri Dao
fd20f16a4e
Support cache_seqlens being integer
2023-09-05 11:27:48 -07:00
Tri Dao
37c6e05406
Implement flash_attn_with_kvcache
2023-09-04 00:11:44 -07:00
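A minimal decoding-step sketch, assuming the flash_attn_with_kvcache interface as documented: the function writes the new k/v into the caches in place at the positions given by cache_seqlens and attends over the valid prefix; the caller advances the lengths.

```python
# Sketch: toy decode loop with a preallocated KV cache.
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 8, 128, 4096
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
# cache_seqlens may also be a single Python int (same length for every sequence).
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device="cuda")

for step in range(16):
    q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
    k_new = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
    v_new = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
    out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                                  cache_seqlens=cache_seqlens, causal=True)
    cache_seqlens += 1  # the kernel does not advance the lengths itself
```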
Tri Dao
9e5e8bc91e
Change causal mask to be aligned to bottom-right instead of top-left
2023-08-24 23:41:07 -07:00
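The change above aligns the causal mask to the bottom-right of the score matrix when seqlen_q != seqlen_k, which is what KV-cache decoding expects. A small illustration, with shapes chosen only to make the masking visible:

```python
# Sketch: with causal=True and seqlen_q < seqlen_k, query i may attend to keys
# 0 .. i + (seqlen_k - seqlen_q), i.e. the mask hugs the bottom-right corner,
# so the last query sees every key (the old top-left convention would not).
import torch
from flash_attn import flash_attn_func

q = torch.randn(1, 2, 8, 64, device="cuda", dtype=torch.float16)  # 2 new queries
k = torch.randn(1, 5, 8, 64, device="cuda", dtype=torch.float16)  # 5 keys total
v = torch.randn(1, 5, 8, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)
```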
Tri Dao
d431f16751
Import torch before flash_attn_2_cuda
2023-08-19 21:07:33 -07:00
Tri Dao
f1a73d0740
Run isort and black on python files
2023-08-18 14:22:11 -07:00
Tri Dao
8f4cd4c16b
[Docs] Fix docstring about Q nheads being divisible by KV nheads
2023-07-31 17:47:03 -07:00
Tri Dao
840f7925a0
[Docs] Fix mention of MQA/GQA in qkvpacked functions
2023-07-28 12:26:29 -10:00
Kirthi Shankar Sivamani
a03f6f8e9e
Enable CUDA graphs (#386)
...
* Add RNG state to kernel launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Save seed and offset for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Single thread write to global mem
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* compute_dq_dk_dv_1colblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* compute_dq_dk_dv_1rowblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Change forward c++ APIs to save RNG state for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Change backward c++ APIs to set RNG state for bprop launcher
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Bug fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Python side API changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Bug fix; only save seeds instead of full offset
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Account for 3D grid size
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-27 16:11:34 -07:00
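The RNG plumbing above is what makes a dropout-enabled forward capturable. A sketch of the standard capture recipe, assuming static shapes; the torch.cuda.graph usage is the generic PyTorch pattern, not something specific to this repo:

```python
# Sketch: capturing a dropout-enabled flash-attention forward in a CUDA graph.
import torch
from flash_attn import flash_attn_func

q, k, v = [torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16)
           for _ in range(3)]

# Warm up on a side stream, then capture with static input buffers.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        flash_attn_func(q, k, v, dropout_p=0.1, causal=True)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = flash_attn_func(q, k, v, dropout_p=0.1, causal=True)

q.copy_(torch.randn_like(q))  # refill the static buffers, then replay
g.replay()
```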
Tri Dao
b4cc152e97
Make sure dout is contiguous
2023-07-17 21:54:44 -07:00
Tri Dao
4f285b3547
FlashAttention-2 release
2023-07-17 06:21:34 -07:00
Tri Dao
e8a0b4acdd
[Doc] Change total -> total_q
2023-07-02 17:23:52 -07:00
Kirthi Shankar Sivamani
7d25a4ec4f
Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-13 06:25:52 +00:00
Kirthi Shankar Sivamani
31018c5fa0
Support CUDA graph capture
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-12 16:53:22 -07:00
Kirthi Shankar Sivamani
b6aa059bbf
Add option for deterministic execution
2023-03-30 18:23:35 -07:00
Tri Dao
88c4e5dbf6
Fix the case when dout is not contiguous
2022-12-13 13:58:17 -08:00
Tri Dao
557781933d
Parallelize CUDA bwd along seqlen_k instead of seqlen_q
...
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
Tri Dao
46fd2a20b2
Support all head dims that are multiples of 8, up to 128
2022-10-24 16:04:21 -07:00
Tri Dao
a5a8806d1a
Split bwd on the seqlen_q dimension
2022-10-23 11:35:15 -07:00
Tri Dao
a44f48df5a
Split fwd on the seqlen_q dimension
2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6
Rework dropout to decouple forward and backward
...
They don't have to have the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
Tri Dao
1b9facacc3
Fix QKV interface to allocate output in Python
2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848
Implement attention kernel that splits the batch into two
2022-10-13 20:49:02 -07:00
Tri Dao
a5559a0e75
Do P * dP (pointwise) in the bwd in fp32 instead of fp16
2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af
Implement cross attention
2022-07-03 17:48:12 -07:00
Tri Dao
5a61cb7729
Rename src -> flash_attn
2022-06-01 18:50:26 -07:00