Commit Graph

35 Commits

Author SHA1 Message Date
Tri Dao
732654583c Implement deterministic backward (thanks to Meituan) 2023-12-23 17:57:36 -08:00
Tri Dao
5ab9b3667b Clean up alibi, implement non-causal alibi 2023-12-21 22:27:40 -08:00
Tri Dao
bc28eacc60 Format flash_attn_interface.py 2023-12-19 23:13:53 -08:00
Sanghun Cho
e4f726fc44
Support alibi, by Sanghun Cho from Kakao Brain
* hard-code alibi in fwd

* use params.h as hun_heads

* hard-code alibi in bwd

* add alibi on/off option

* compute alibi_start, ratio outside of kernels

* fix minor merge conflict

* add test_alibi.py

* change apply_alibi() location before masking

* add alibi in splitkv kernel

* fix backward func # of returns

* add out-of-bound check in apply_alibi()

* update test_alibi.py

* update test_alibi.py for kvcache

* simplify alibi parameter interface

* fix performance issue
by computing alibi outside of branch

* update test_flash_attn_varlen_func() for left padding

* implement alibi_slopes (b, nh) loading

* optimize apply_alibi() a bit

* update test cases for alibi_slopes loading

* reflect stylistic comments

* disable "seqlenq_ngroups_swapped" when using alibi

---------

Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
2023-12-19 22:56:06 -08:00
Tri Dao
d4a7c8ffbb [CI] Only compile for CUDA 11.8 & 12.2, MAX_JOBS=2,add torch-nightly 2023-11-27 16:21:28 -08:00
Jeremy Reizenstein
ce3e7280f8
Allow varlen_fwd to take optional seqused_k (#647)
Co-authored-by: bottler <bottler@users.noreply.github.com>
2023-11-27 00:41:23 -08:00
Tri Dao
e279bf8ed9 [Gen] Accept cache_batch_idx to index into the KV cache 2023-10-03 16:27:26 -07:00
Tri Dao
083e8f525f Implement local attention
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
Tri Dao
ccbb14f38e Implement rotary embedding in flash_attn_with_kvcache 2023-09-16 01:20:16 -07:00
Tri Dao
ee77b931b9 Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza) 2023-09-10 22:56:33 -07:00
Tri Dao
fd20f16a4e Support cache_seqlens being integer 2023-09-05 11:27:48 -07:00
Tri Dao
37c6e05406 Implement flash_attn_with_kvcache 2023-09-04 00:11:44 -07:00
Tri Dao
9e5e8bc91e Change causal mask to be aligned to bottom-right instead of top-left 2023-08-24 23:41:07 -07:00
Tri Dao
d431f16751 Import torch before flash_attn_2_cuda 2023-08-19 21:07:33 -07:00
Tri Dao
f1a73d0740 Run isort and black on python files 2023-08-18 14:22:11 -07:00
Tri Dao
8f4cd4c16b [Docs] Fix docstring about Q nheads being divisible by KV nheads 2023-07-31 17:47:03 -07:00
Tri Dao
840f7925a0 [Docs] Fix mention of MQA/GQA in qkvpacked functions 2023-07-28 12:26:29 -10:00
Kirthi Shankar Sivamani
a03f6f8e9e
Enable CUDA graphs (#386)
* Add RNG state to kernel launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Save seed and offset for backward

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Single thread write to global mem

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1colblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1rowblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change forward c++ APIs to save RNG state for backward

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change backward c++ APIs to set RNG state for bprop launcher

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Python side API changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix; only save seeds instead of full offset

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Account for 3D grid size

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-27 16:11:34 -07:00
Tri Dao
b4cc152e97 Make sure dout is contiguous 2023-07-17 21:54:44 -07:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00
Tri Dao
e8a0b4acdd [Doc] Change total -> total_q 2023-07-02 17:23:52 -07:00
Kirthi Shankar Sivamani
7d25a4ec4f Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-13 06:25:52 +00:00
Kirthi Shankar Sivamani
31018c5fa0 Support CUDA graph capture
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-12 16:53:22 -07:00
Kirthi Shankar Sivamani
b6aa059bbf Add option for deterministic execution 2023-03-30 18:23:35 -07:00
Tri Dao
88c4e5dbf6 Fix the case when dout is not contiguous 2022-12-13 13:58:17 -08:00
Tri Dao
557781933d Parallelize CUDA bwd along seqlen_k instead of seqlen_q
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
Tri Dao
46fd2a20b2 Support all head dims that are multiples of 8, up to 128 2022-10-24 16:04:21 -07:00
Tri Dao
a5a8806d1a Split bwd on the seqlen_q dimension 2022-10-23 11:35:15 -07:00
Tri Dao
a44f48df5a Split fwd on the seqlen_q dimension 2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6 Rework dropout to decouple forward and backward
They don't have to have the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
Tri Dao
1b9facacc3 Fix QKV interface to allocate output in Python 2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848 Implement attention kernel that splits the batch into two 2022-10-13 20:49:02 -07:00
Tri Dao
a5559a0e75 Do P * dP (pointwise) in the bwd in fp32 instead of fp16 2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af Implement cross attention 2022-07-03 17:48:12 -07:00
Tri Dao
5a61cb7729 Rename src -> flash_attn 2022-06-01 18:50:26 -07:00