Commit Graph

26 Commits

Author SHA1 Message Date
Tri Dao
2406f28805 Enable headdim 256 backward on consumer GPUs (Ampere, Ada) 2024-02-21 15:56:19 -08:00
Tri Dao
d9a5cb291c Fix dv = torch::empty_like(k) for mha_bwd_varlen as well 2024-02-10 01:03:00 -08:00
Brian Hirsh
2423cca3ad
fix backward for when query and key have different contiguity (#818) 2024-02-10 01:01:27 -08:00
Grigory Sizov
4687936413
Fix Windows build (#816) 2024-02-07 17:41:53 -08:00
Jeremy Reizenstein
0658e320f6
Preprocessor switches to control functionality (#788)
For faster and smaller builds in some simple cases,
provide switches to allow disabling
-backward
-alibi
-uneven k
-dropout
-local attention

Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
2024-01-29 20:44:23 -08:00
Tri Dao
54e80a3829 Implement page KV cache
Co-authored-by: ljss <450993438@qq.com>
2024-01-22 22:47:30 -08:00
Tri Dao
ea8a25ca38 Remove configure in bwd kernel launch 2024-01-21 15:28:33 -08:00
Grigory Sizov
af01244ddd
Add split-kv and M<->H swap to varlen forward decoding attention (#754)
* Add split-k, M<->H to varseq path

* skip M<->H when dropout>0, fix LSE
2024-01-21 15:28:36 -08:00
Tri Dao
0842ec0da4 Don't dispatch to local if window size >= seqlen_k 2023-12-23 20:59:26 -08:00
Tri Dao
732654583c Implement deterministic backward (thanks to Meituan) 2023-12-23 17:57:36 -08:00
Tri Dao
5ab9b3667b Clean up alibi, implement non-causal alibi 2023-12-21 22:27:40 -08:00
Sanghun Cho
e4f726fc44
Support alibi, by Sanghun Cho from Kakao Brain
* hard-code alibi in fwd

* use params.h as hun_heads

* hard-code alibi in bwd

* add alibi on/off option

* compute alibi_start, ratio outside of kernels

* fix minor merge conflict

* add test_alibi.py

* change apply_alibi() location before masking

* add alibi in splitkv kernel

* fix backward func # of returns

* add out-of-bound check in apply_alibi()

* update test_alibi.py

* update test_alibi.py for kvcache

* simplify alibi parameter interface

* fix performance issue
by computing alibi outside of branch

* update test_flash_attn_varlen_func() for left padding

* implement alibi_slopes (b, nh) loading

* optimize apply_alibi() a bit

* update test cases for alibi_slopes loading

* reflect stylistic comments

* disable "seqlenq_ngroups_swapped" when using alibi

---------

Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
2023-12-19 22:56:06 -08:00
Jeremy Reizenstein
ce3e7280f8
Allow varlen_fwd to take optional seqused_k (#647)
Co-authored-by: bottler <bottler@users.noreply.github.com>
2023-11-27 00:41:23 -08:00
Tri Dao
db2f80692c Write zero to out / grad if seqlen_q or seqlen_k is zero 2023-11-19 22:20:01 -08:00
Tri Dao
e279bf8ed9 [Gen] Accept cache_batch_idx to index into the KV cache 2023-10-03 16:27:26 -07:00
Tri Dao
083e8f525f Implement local attention
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
Tri Dao
65c234ed90 Don't over-allocate dq_accum in case of varlen 2023-09-24 00:36:07 -07:00
Tri Dao
2d8ea9a530 Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza) 2023-09-20 23:38:22 -07:00
Tri Dao
3250ff3d82 Swap seqlen_q, nheads for MQA when seqlen_q=1 for fwd (h/t Daniel H) 2023-09-18 14:52:16 -07:00
Tri Dao
ccbb14f38e Implement rotary embedding in flash_attn_with_kvcache 2023-09-16 01:20:16 -07:00
Tri Dao
bb9beb3645 Remove some unused headers 2023-09-12 12:37:10 -07:00
Tri Dao
ee77b931b9 Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza) 2023-09-10 22:56:33 -07:00
Tri Dao
37c6e05406 Implement flash_attn_with_kvcache 2023-09-04 00:11:44 -07:00
Tri Dao
b1fbbd8337 Implement splitKV attention 2023-08-29 00:58:29 -07:00
Kirthi Shankar Sivamani
a03f6f8e9e
Enable CUDA graphs (#386)
* Add RNG state to kernel launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Save seed and offset for backward

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Single thread write to global mem

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1colblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1rowblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change forward c++ APIs to save RNG state for backward

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change backward c++ APIs to set RNG state for bprop launcher

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Python side API changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix; only save seeds instead of full offset

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Account for 3D grid size

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-27 16:11:34 -07:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00