Commit Graph

93 Commits

Author SHA1 Message Date
Sanghun Cho
e4f726fc44
Support alibi, by Sanghun Cho from Kakao Brain
* hard-code alibi in fwd

* use params.h as num_heads

* hard-code alibi in bwd

* add alibi on/off option

* compute alibi_start, ratio outside of kernels

* fix minor merge conflict

* add test_alibi.py

* change apply_alibi() location before masking

* add alibi in splitkv kernel

* fix backward func # of returns

* add out-of-bound check in apply_alibi()

* update test_alibi.py

* update test_alibi.py for kvcache

* simplify alibi parameter interface

* fix performance issue by computing alibi outside of branch

* update test_flash_attn_varlen_func() for left padding

* implement alibi_slopes (b, nh) loading

* optimize apply_alibi() a bit

* update test cases for alibi_slopes loading

* reflect stylistic comments

* disable "seqlenq_ngroups_swapped" when using alibi

---------

Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
2023-12-19 22:56:06 -08:00
Jeremy Reizenstein
ce3e7280f8
Allow varlen_fwd to take optional seqused_k (#647)
Co-authored-by: bottler <bottler@users.noreply.github.com>
2023-11-27 00:41:23 -08:00
Tri Dao
b4bf9cc1f3 Fix performance regression with causal 2023-11-26 19:07:25 -08:00
Tri Dao
db2f80692c Write zero to out / grad if seqlen_q or seqlen_k is zero 2023-11-19 22:20:01 -08:00
Driss Guessous
dc4b9ad6c4
add checks (#640) 2023-11-19 20:43:27 -08:00
Tri Dao
5a83425442 Change constexpr int to constexpr static int 2023-10-08 16:26:33 -07:00
Tri Dao
e279bf8ed9 [Gen] Accept cache_batch_idx to index into the KV cache 2023-10-03 16:27:26 -07:00
Tri Dao
083e8f525f Implement local attention
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
Tri Dao
65c234ed90 Don't over-allocate dq_accum in case of varlen 2023-09-24 00:36:07 -07:00
Tri Dao
1879e089c7 Reduce number of templates for headdim > 128 2023-09-23 22:24:30 -07:00
Tri Dao
2d8ea9a530 Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza) 2023-09-20 23:38:22 -07:00
Tri Dao
3250ff3d82 Swap seqlen_q, nheads for MQA when seqlen_q=1 for fwd (h/t Daniel H) 2023-09-18 14:52:16 -07:00
Tri Dao
43617deab9 Remove template for (IsEvenMN=T, IsEvenK=F) to speed up compilation 2023-09-18 12:21:36 -07:00
Tri Dao
c984208ddb Set block size to 64 x 64 for kvcache to avoid nvcc segfaults 2023-09-17 16:14:58 -07:00
Tri Dao
ccbb14f38e Implement rotary embedding in flash_attn_with_kvcache 2023-09-16 01:20:16 -07:00
Tri Dao
56b7fc6ee0 Simplify the implementation of KVcache attn by appending KV first 2023-09-13 15:55:48 -07:00
Tri Dao
bb9beb3645 Remove some unused headers 2023-09-12 12:37:10 -07:00
Tri Dao
ee77b931b9 Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza) 2023-09-10 22:56:33 -07:00
Tri Dao
37c6e05406 Implement flash_attn_with_kvcache 2023-09-04 00:11:44 -07:00
Tri Dao
6a89b2f121 Remove constexpr in launch template to fix CI compilation 2023-09-03 22:59:41 -07:00
Tri Dao
1dc1b6c8f2 Bump to v2.1.2 2023-09-03 22:23:05 -07:00
Tri Dao
5953c4f58c Remove unused sdPsum in dot_do_o function 2023-09-03 20:44:07 -07:00
Tri Dao
26d7d92f3d Fix splitKV combine function when local LSEs are all -inf 2023-09-03 11:39:09 -07:00
Sophia Wisdom
37e32febba
Remove commented out code in bwd (#512)
* Remove lots of comments

* Remove unused traits
2023-09-01 16:43:58 -07:00
Sophia Wisdom
dd8a754915
Remove old code in utils.h (#511) 2023-09-01 15:32:09 -07:00
Tri Dao
31920dda5f Fix typo with lse_max == -INFINITY 2023-08-29 21:48:59 -07:00
Tri Dao
b1fbbd8337 Implement splitKV attention 2023-08-29 00:58:29 -07:00
Tri Dao
7a983df742 Use generate_kernels.py script from Driss Guessous 2023-08-28 13:34:12 -07:00
Tri Dao
9e5e8bc91e Change causal mask to be aligned to bottom-right instead of top-left 2023-08-24 23:41:07 -07:00
BoxiangW
e07aa036db
Support flash attention 2 with causal masking when KV's seq length is longer than Q's seq length. (#436) 2023-08-24 16:42:34 -07:00
Tri Dao
bcfa7c9751 [FusedDense] Run black on fused_dense.py 2023-08-16 23:41:36 -07:00
Tri Dao
c65b5106ac Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal 2023-08-16 15:12:36 -07:00
Tri Dao
dbd7923782 Prepare for Cutlass 3.2 2023-08-13 15:24:32 -07:00
Tri Dao
3524e13c11 Update to Cutlass 3.1 2023-08-13 13:53:17 -07:00
Tri Dao
1c41d2b0e5 Fix race condition in bwd (overwriting sK) 2023-08-01 09:00:10 -07:00
Tri Dao
a4f148b6ab Fix masking of bwd when seqlen is not divisible by 128 2023-07-31 17:46:34 -07:00
Kirthi Shankar Sivamani
a03f6f8e9e
Enable CUDA graphs (#386)
* Add RNG state to kernel launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Save seed and offset for backward

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Single thread write to global mem

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1colblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1rowblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change forward c++ APIs to save RNG state for backward

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change backward c++ APIs to set RNG state for bprop launcher

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Python side API changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix; only save seeds instead of full offset

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Account for 3D grid size

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-27 16:11:34 -07:00
Tri Dao
9ee0ff1d9b Fix using dO stride for O, which can cause memory error in bwd 2023-07-20 17:39:57 -07:00
danthe3rd
538d570c96 Fix compile error on MSVC
See also: https://stackoverflow.com/questions/55136414/constexpr-variable-captured-inside-lambda-loses-its-constexpr-ness
2023-07-19 08:04:57 +00:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00
Tri Dao
ad113948a6 [Docs] Clearer error message for bwd d > 64, bump to v1.0.4 2023-04-26 09:19:48 -07:00
Kirthi Shankar Sivamani
45567a25a2 only 1 thread writes to global mem in fprop
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-15 06:09:41 +00:00
Kirthi Shankar Sivamani
7d25a4ec4f Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-13 06:25:52 +00:00
Kirthi Shankar Sivamani
31018c5fa0 Support CUDA graph capture
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-12 16:53:22 -07:00
Tri Dao
1b18f1b7a1 Support H100 2023-03-15 14:59:02 -07:00
Tri Dao
6b4a48218e [FA] Remove unused variable rng_engine_inputs 2023-01-25 15:32:40 -08:00
Tri Dao
a1f49a2b92 [Compilation] Change BOOL_SWITCH to fix Windows compilation
Follow xFormers's DISPATCH_BOOL. Haven't tested it on Windows.
2023-01-06 14:40:58 -08:00
Tri Dao
8a2ece89f7 Simplify BOOL_SWITCH macro to fix compiling error on gcc 7 2022-12-06 14:38:32 -08:00
Tri Dao
9bc63d1e2d Fix typo in comments 2022-11-25 16:35:08 -08:00
Tri Dao
d95ee1a95d Speed up compilation by splitting into separate .cu files 2022-11-25 16:30:18 -08:00