Tri Dao
65f723bb9a
Split bwd into more .cu files to speed up compilation
2024-07-23 01:32:09 -07:00
Tri Dao
751c762c9c
Don't specialize for hdim 224 to speed up compilation
2024-07-23 00:13:54 -07:00
Driss Guessous
1c275eb070
Fix ima for split-kv kernel ( #1085 )
2024-07-22 22:19:46 -07:00
Jorge António
4df62e1440
catch typo ( #1058 )
2024-07-21 23:24:15 -07:00
Tri Dao
40e534a7f6
Implement cache_leftpad
2024-07-11 08:17:15 -07:00
Tri Dao
dca6d89da4
Don't support softcap and dropout at the same time
...
These tests are failing so I'm just disabling this case for now
2024-07-10 11:23:12 -07:00
Tri Dao
908511b2b6
Split into more .cu files to speed up compilation
2024-07-10 00:24:04 -07:00
Tri Dao
1d536d7de5
Minor cleanup of softcapping
2024-07-09 22:57:03 -07:00
Nicolas Patry
8f873cc6ac
Implement softcapping. ( #1025 )
...
* Softcap v2 (fwd only).
* Some missing interface + remove overrides in tests.
2024-07-08 11:24:48 -07:00
Nicolas Patry
5bf201966a
Fixing argument checking when using seqlenq_ngroups_swapped. ( #976 )
...
When user send `out` as a parameter of the function
`seqlenq_ngroups_swapped` with parameters that trigger,
the CHECK_SHAPE is incorrect (since q shape is modified.)
2024-06-30 22:39:22 -07:00
Grigory Sizov
f816dee63c
Support unpadded LSE layout ( #970 )
...
* Support unpadded LSE layout.
Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
* Cleanup
* Fix unpadded LSE on split-kv path
* Fix formatting and comments
* Fix inline vs forceinline
---------
Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
2024-06-27 02:38:13 -07:00
Tri Dao
9eb3d099c1
Transpose out when swapping seqlen_q and num_groups
2024-04-07 20:10:19 -07:00
Driss Guessous
4a73e903da
Add in, macrosf for defining __grid_constant__ ( #852 )
2024-03-15 00:48:54 -07:00
Grigory Sizov
2a15840f09
Enable paged attention in varlen forward ( #831 )
...
* Enable paged attention in varlen forward
* Format + fix padding
2024-03-15 00:48:19 -07:00
Tri Dao
2406f28805
Enable headdim 256 backward on consumer GPUs (Ampere, Ada)
2024-02-21 15:56:19 -08:00
Tri Dao
d9a5cb291c
Fix dv = torch::empty_like(k) for mha_bwd_varlen as well
2024-02-10 01:03:00 -08:00
Brian Hirsh
2423cca3ad
fix backward for when query and key have different contiguity ( #818 )
2024-02-10 01:01:27 -08:00
Grigory Sizov
4687936413
Fix Windows build ( #816 )
2024-02-07 17:41:53 -08:00
Jeremy Reizenstein
0658e320f6
Preprocessor switches to control functionality ( #788 )
...
For faster and smaller builds in some simple cases,
provide switches to allow disabling
-backward
-alibi
-uneven k
-dropout
-local attention
Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
2024-01-29 20:44:23 -08:00
Tri Dao
54e80a3829
Implement page KV cache
...
Co-authored-by: ljss <450993438@qq.com>
2024-01-22 22:47:30 -08:00
Tri Dao
ea8a25ca38
Remove configure in bwd kernel launch
2024-01-21 15:28:33 -08:00
Grigory Sizov
af01244ddd
Add split-kv and M<->H swap to varlen forward decoding attention ( #754 )
...
* Add split-k, M<->H to varseq path
* skip M<->H when dropout>0, fix LSE
2024-01-21 15:28:36 -08:00
Tri Dao
0842ec0da4
Don't dispatch to local if window size >= seqlen_k
2023-12-23 20:59:26 -08:00
Tri Dao
732654583c
Implement deterministic backward (thanks to Meituan)
2023-12-23 17:57:36 -08:00
Tri Dao
5ab9b3667b
Clean up alibi, implement non-causal alibi
2023-12-21 22:27:40 -08:00
Sanghun Cho
e4f726fc44
Support alibi, by Sanghun Cho from Kakao Brain
...
* hard-code alibi in fwd
* use params.h as hun_heads
* hard-code alibi in bwd
* add alibi on/off option
* compute alibi_start, ratio outside of kernels
* fix minor merge conflict
* add test_alibi.py
* change apply_alibi() location before masking
* add alibi in splitkv kernel
* fix backward func # of returns
* add out-of-bound check in apply_alibi()
* update test_alibi.py
* update test_alibi.py for kvcache
* simplify alibi parameter interface
* fix performance issue
by computing alibi outside of branch
* update test_flash_attn_varlen_func() for left padding
* implement alibi_slopes (b, nh) loading
* optimize apply_alibi() a bit
* update test cases for alibi_slopes loading
* reflect stylistic comments
* disable "seqlenq_ngroups_swapped" when using alibi
---------
Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
2023-12-19 22:56:06 -08:00
Jeremy Reizenstein
ce3e7280f8
Allow varlen_fwd to take optional seqused_k ( #647 )
...
Co-authored-by: bottler <bottler@users.noreply.github.com>
2023-11-27 00:41:23 -08:00
Tri Dao
db2f80692c
Write zero to out / grad if seqlen_q or seqlen_k is zero
2023-11-19 22:20:01 -08:00
Tri Dao
e279bf8ed9
[Gen] Accept cache_batch_idx to index into the KV cache
2023-10-03 16:27:26 -07:00
Tri Dao
083e8f525f
Implement local attention
...
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
Tri Dao
65c234ed90
Don't over-allocate dq_accum in case of varlen
2023-09-24 00:36:07 -07:00
Tri Dao
2d8ea9a530
Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza)
2023-09-20 23:38:22 -07:00
Tri Dao
3250ff3d82
Swap seqlen_q, nheads for MQA when seqlen_q=1 for fwd (h/t Daniel H)
2023-09-18 14:52:16 -07:00
Tri Dao
ccbb14f38e
Implement rotary embedding in flash_attn_with_kvcache
2023-09-16 01:20:16 -07:00
Tri Dao
bb9beb3645
Remove some unused headers
2023-09-12 12:37:10 -07:00
Tri Dao
ee77b931b9
Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza)
2023-09-10 22:56:33 -07:00
Tri Dao
37c6e05406
Implement flash_attn_with_kvcache
2023-09-04 00:11:44 -07:00
Tri Dao
b1fbbd8337
Implement splitKV attention
2023-08-29 00:58:29 -07:00
Kirthi Shankar Sivamani
a03f6f8e9e
Enable CUDA graphs ( #386 )
...
* Add RNG state to kernel launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Save seed and offset for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Single thread write to global mem
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* compute_dq_dk_dv_1colblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* compute_dq_dk_dv_1rowblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Change forward c++ APIs to save RNG state for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Change backward c++ APIs to set RNG state for bprop launcher
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Bug fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Python side API changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Bug fix; only save seeds instead of full offset
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Account for 3D grid size
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-27 16:11:34 -07:00
Tri Dao
4f285b3547
FlashAttention-2 release
2023-07-17 06:21:34 -07:00