Commit Graph

61 Commits

Author SHA1 Message Date
Ying Zhang
cdbbe844b1 minor changes to unpad_input test util func 2024-09-16 14:24:11 -07:00
Tri Dao
299563626f Fix test with alibi and cache_leftpad 2024-07-23 02:04:15 -07:00
Tri Dao
751c762c9c Don't specialize for hdim 224 to speed up compilation 2024-07-23 00:13:54 -07:00
Phil Wang
5f1ae4a34b
backwards for softcapping (#1033)
* check in the two ways of approaching backwards for softcapping, both functional

* prepare the softcap switch for backwards

* temporary

* cleanup to the way Tri prefers

* calculate dtanh when copying from scores -> dtanh Tensor

* no ternary operators allowed for constexpr, so just use some hack found online

* fix maybe_dtanh, restore some files

* restore another file

* move calculate_dtanh to utils and colocate with apply_softcap

* cleanup

* maybe last cleanup

* save for another pr

* remove a stray line

* fix spacing

* fix an issue, and make test_flash_attn.py ready to test softcapping backwards
2024-07-21 23:25:46 -07:00
Tri Dao
40e534a7f6 Implement cache_leftpad 2024-07-11 08:17:15 -07:00
Tri Dao
d0787acc16 Relax dropout_fraction test 2024-07-10 11:49:40 -07:00
Tri Dao
dca6d89da4 Don't support softcap and dropout at the same time
These tests are failing so I'm just disabling this case for now
2024-07-10 11:23:12 -07:00
Tri Dao
81e01efd4b More typo fixes 2024-07-10 10:19:17 -07:00
Tri Dao
3d41db3e2c Only test backward if there's no softcapping 2024-07-10 00:27:45 -07:00
Nicolas Patry
8f873cc6ac
Implement softcapping. (#1025)
* Softcap v2 (fwd only).

* Some missing interface + remove overrides in tests.
2024-07-08 11:24:48 -07:00
muoshuosha
6df7e0a02e
Fix the varlen deterministic test (#1023)
Co-authored-by: moshuosha <moshuosha@qq.com>
2024-07-03 11:07:57 -07:00
cao lei
6a2a16e994
fix typo (#974) 2024-06-30 22:39:39 -07:00
Grigory Sizov
f816dee63c
Support unpadded LSE layout (#970)
* Support unpadded LSE layout.

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>

* Cleanup

* Fix unpadded LSE on split-kv path

* Fix formatting and comments

* Fix inline vs forceinline

---------

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
2024-06-27 02:38:13 -07:00
Grigory Sizov
2a15840f09
Enable paged attention in varlen forward (#831)
* Enable paged attention in varlen forward

* Format + fix padding
2024-03-15 00:48:19 -07:00
Tri Dao
2406f28805 Enable headdim 256 backward on consumer GPUs (Ampere, Ada) 2024-02-21 15:56:19 -08:00
Tri Dao
54e80a3829 Implement page KV cache
Co-authored-by: ljss <450993438@qq.com>
2024-01-22 22:47:30 -08:00
Tri Dao
10dad61277 apply_dropout now takes tensor of rowcol layout 2024-01-14 01:03:23 -08:00
Tri Dao
a7b66ae25a Simplify writing softmax to gmem 2024-01-13 00:25:04 -08:00
Tri Dao
732654583c Implement deterministic backward (thanks to Meituan) 2023-12-23 17:57:36 -08:00
Tri Dao
5ab9b3667b Clean up alibi, implement non-causal alibi 2023-12-21 22:27:40 -08:00
Tri Dao
e279bf8ed9 [Gen] Accept cache_batch_idx to index into the KV cache 2023-10-03 16:27:26 -07:00
Tri Dao
083e8f525f Implement local attention
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
Tri Dao
65c234ed90 Don't over-allocate dq_accum in case of varlen 2023-09-24 00:36:07 -07:00
Tri Dao
2d8ea9a530 Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza) 2023-09-20 23:38:22 -07:00
Tri Dao
3250ff3d82 Swap seqlen_q, nheads for MQA when seqlen_q=1 for fwd (h/t Daniel H) 2023-09-18 14:52:16 -07:00
Tri Dao
ccbb14f38e Implement rotary embedding in flash_attn_with_kvcache 2023-09-16 01:20:16 -07:00
Tri Dao
56b7fc6ee0 Simplify the implementation of KVcache attn by appending KV first 2023-09-13 15:55:48 -07:00
Tri Dao
37c6e05406 Implement flash_attn_with_kvcache 2023-09-04 00:11:44 -07:00
Tri Dao
0c04943fa2 Require CUDA 11.6+, clean up setup.py 2023-09-03 21:24:56 -07:00
Tri Dao
b1fbbd8337 Implement splitKV attention 2023-08-29 00:58:29 -07:00
Tri Dao
9e5e8bc91e Change causal mask to be aligned to bottom-right instead of top-left 2023-08-24 23:41:07 -07:00
Tri Dao
0e8c46ae08 Run isort and black on test files 2023-08-18 20:59:35 -07:00
Tri Dao
c65b5106ac Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal 2023-08-16 15:12:36 -07:00
Tri Dao
3524e13c11 Update to Cutlass 3.1 2023-08-13 13:53:17 -07:00
Tri Dao
1c41d2b0e5 Fix race condition in bwd (overwriting sK) 2023-08-01 09:00:10 -07:00
Tri Dao
a4f148b6ab Fix masking of bwd when seqlen is not divisible by 128 2023-07-31 17:46:34 -07:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00
Tri Dao
a8fec99a9a Skip flash_attn_split test 2022-11-13 12:27:48 -08:00
Tri Dao
9d3116addf Don't enforce bitwise consistency for dq in race condition test
Since we could be parallelizing over seqlen_k
2022-11-13 12:21:51 -08:00
Tri Dao
6998e0ecdb Fix out-of-bound memory read 2022-11-09 09:34:14 -08:00
Tri Dao
7479757191 Fix pipelining bug in Triton bwd with bias_type=matrix 2022-11-06 11:50:35 -08:00
Tri Dao
557781933d Parallelize CUDA bwd along seqlen_k instead of seqlen_q
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
Tri Dao
ff78ea4123 Fix race condition in Triton bwd when there's bias 2022-11-04 11:20:27 -07:00
Tri Dao
86862cfd7b Implement attention bias for Triton version 2022-11-04 10:33:54 -07:00
Tri Dao
aacc10fbab Fix race condition in Triton bwd for non-po2 headdims 2022-11-02 07:32:54 -07:00
Tri Dao
1fb12afdfb Avoid memcpy in the Triton bwd 2022-11-01 15:06:45 -07:00
Tri Dao
9b0bc97872 Fix race condition in Triton fwd 2022-10-31 14:34:57 -07:00
Tri Dao
4f81aff46e Add debug_barrier for all headdims in Triton bwd 2022-10-31 01:25:02 -07:00
Tri Dao
e78d509c64 [WIP] Support all head dimensions up to 128 in the Triton bwd
WIP because there seems to be some race conditions for head dimensions other
than 16, 32, 64, 128.
2022-10-31 00:46:22 -07:00
Tri Dao
008951f1d9 Support all head dimensions up to 128 in the Triton fwd 2022-10-30 22:10:48 -07:00