Commit Graph

170 Commits

Author SHA1 Message Date
Tri Dao
c7f32a8409 [CrossEntropy] Support precomputed LSE 2024-09-08 09:24:43 -07:00
Jay Shah
32792d37ec add missing if condition for key_padding_mask in test_util.py 2024-08-19 11:17:17 -07:00
Ying Zhang
53537da422 add a unittest 2024-08-17 13:23:50 -07:00
Tri Dao
299563626f Fix test with alibi and cache_leftpad 2024-07-23 02:04:15 -07:00
Tri Dao
751c762c9c Don't specialize for hdim 224 to speed up compilation 2024-07-23 00:13:54 -07:00
rocking
d8f104e97a
Support AMD ROCm on FlashAttention 2 (#1010)
* Support ck in fmha

* Add ck submodule

* Do not return lse if return_softmax == false

* Use receipt to speed up ck compile time

* Integrate new version of ck_tile

* Support dropout for mha_fwd()

* Add dropout to mha_varlen_fwd()

* Update ck to develop

* Extract padding function for dropout randval

* Extract randval transformation function

* Sync the code structure and coding style with FA

* Remove this line, c++ api will handle this.
Sync with test_flash_attn.py

* fix compile error

* Add mha_bwd

* Generate dropout seed and offset from user generator

* update CK

* Add mha_varlen_bwd

* Use same python as build flash-attn to generate ck kernel

* Fix bug of group mode fwd about returning softmax lse

* larger the test tollerance

* Add test_flash_attn_output() and test_flash_attn_varlen_output()

* Always fill softmax_lse

* Remove duplicate benchmark script, since we already implement mha_bwd

* Refine get value from tuple

* Use default parameter for stream_config

* unblock all platform

* Add comment

* refine the test code

* Refine naming

* Add unpack to namespace

* Do not hardcode the warp size 64

* Add more targets

* Add README

* Optimize mha_fwd if seqlen_q == 1

* Support get_wheel_url for rocm

* Detect rocm environment by pytorch's IS_HIP_EXTENSION

* update to lastest ck

* Add necessary compile flag

* Sync the api with upstream FA

---------

Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
2024-07-22 21:34:37 -07:00
Ying Zhang
dfe1a59e4b
Add var-seq-len to FA3 fp16 / bf16 fwd (#1072)
* fwd var-seq-len

* fixes

* benchmark

* fixes

---------

Co-authored-by: Tri Dao <tridao@users.noreply.github.com>
2024-07-22 21:32:41 -07:00
Phil Wang
5f1ae4a34b
backwards for softcapping (#1033)
* check in the two ways of approaching backwards for softcapping, both functional

* prepare the softcap switch for backwards

* temporary

* cleanup to the way Tri prefers

* calculate dtanh when copying from scores -> dtanh Tensor

* no ternary operators allowed for constexpr, so just use some hack found online

* fix maybe_dtanh, restore some files

* restore another file

* move calculate_dtanh to utils and colocate with apply_softcap

* cleanup

* maybe last cleanup

* save for another pr

* remove a stray line

* fix spacing

* fix an issue, and make test_flash_attn.py ready to test softcapping backwards
2024-07-21 23:25:46 -07:00
Tri Dao
40e534a7f6 Implement cache_leftpad 2024-07-11 08:17:15 -07:00
Tri Dao
d0787acc16 Relax dropout_fraction test 2024-07-10 11:49:40 -07:00
Tri Dao
dca6d89da4 Don't support softcap and dropout at the same time
These tests are failing so I'm just disabling this case for now
2024-07-10 11:23:12 -07:00
Tri Dao
81e01efd4b More typo fixes 2024-07-10 10:19:17 -07:00
Tri Dao
3d41db3e2c Only test backward if there's no softcapping 2024-07-10 00:27:45 -07:00
Nicolas Patry
8f873cc6ac
Implement softcapping. (#1025)
* Softcap v2 (fwd only).

* Some missing interface + remove overrides in tests.
2024-07-08 11:24:48 -07:00
muoshuosha
6df7e0a02e
Fix the varlen deterministic test (#1023)
Co-authored-by: moshuosha <moshuosha@qq.com>
2024-07-03 11:07:57 -07:00
cao lei
6a2a16e994
fix typo (#974) 2024-06-30 22:39:39 -07:00
Grigory Sizov
f816dee63c
Support unpadded LSE layout (#970)
* Support unpadded LSE layout.

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>

* Cleanup

* Fix unpadded LSE on split-kv path

* Fix formatting and comments

* Fix inline vs forceinline

---------

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
2024-06-27 02:38:13 -07:00
Ivan Komarov
f692b98d80
Fix spurious re-compilations of rotary_kernel (#911)
All integer parameters are specialized by default, so the two parameters
removed in this commit could lead to kernel re-compilation, even if
they were completely unused.
2024-04-05 13:40:41 -07:00
Grigory Sizov
2a15840f09
Enable paged attention in varlen forward (#831)
* Enable paged attention in varlen forward

* Format + fix padding
2024-03-15 00:48:19 -07:00
Tri Dao
2406f28805 Enable headdim 256 backward on consumer GPUs (Ampere, Ada) 2024-02-21 15:56:19 -08:00
Tri Dao
54e80a3829 Implement page KV cache
Co-authored-by: ljss <450993438@qq.com>
2024-01-22 22:47:30 -08:00
Curtis "Fjord" Hawthorne
d8aacc510c
return z_loss (#768) 2024-01-21 15:23:41 -08:00
Tri Dao
10dad61277 apply_dropout now takes tensor of rowcol layout 2024-01-14 01:03:23 -08:00
Tri Dao
a7b66ae25a Simplify writing softmax to gmem 2024-01-13 00:25:04 -08:00
Tri Dao
f5b308e258 [LayerNorm] Rename layernorm.py -> layer_norm.py 2024-01-05 00:21:03 -08:00
Tri Dao
665b55e2e2 [LayerNorm] Implement parallel layer norm in Triton 2024-01-04 23:15:35 -08:00
Tri Dao
aa5c6438c5 [LayerNorm] Implement rowscale in Triton layernorm 2024-01-04 01:07:03 -08:00
Tri Dao
73df3be7d5 Add test for BTLM init 2023-12-25 15:16:27 -08:00
Tri Dao
7ffba9a501 Implement BTLM model 2023-12-24 20:35:12 -08:00
Tri Dao
3f7d5786ba Pass alibi slopes to flash_attn_with_kvcache during generation 2023-12-24 20:31:59 -08:00
Tri Dao
732654583c Implement deterministic backward (thanks to Meituan) 2023-12-23 17:57:36 -08:00
Tri Dao
2c7d7b7396 Implement norm head for Baichuan2 2023-12-22 16:55:40 -08:00
Tri Dao
c3b2196652 Add Alibi to MHA, test with Baichuan-13B 2023-12-21 22:49:55 -08:00
Tri Dao
5ab9b3667b Clean up alibi, implement non-causal alibi 2023-12-21 22:27:40 -08:00
Sanghun Cho
e4f726fc44
Support alibi, by Sanghun Cho from Kakao Brain
* hard-code alibi in fwd

* use params.h as hun_heads

* hard-code alibi in bwd

* add alibi on/off option

* compute alibi_start, ratio outside of kernels

* fix minor merge conflict

* add test_alibi.py

* change apply_alibi() location before masking

* add alibi in splitkv kernel

* fix backward func # of returns

* add out-of-bound check in apply_alibi()

* update test_alibi.py

* update test_alibi.py for kvcache

* simplify alibi parameter interface

* fix performance issue
by computing alibi outside of branch

* update test_flash_attn_varlen_func() for left padding

* implement alibi_slopes (b, nh) loading

* optimize apply_alibi() a bit

* update test cases for alibi_slopes loading

* reflect stylistic comments

* disable "seqlenq_ngroups_swapped" when using alibi

---------

Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
2023-12-19 22:56:06 -08:00
Tri Dao
cd089597fd [LayerNorm] Implement dropout in fused residual + LN/RMSNorm 2023-12-19 16:26:07 -08:00
Tri Dao
713bd3aa9a [CrossEntropy] Test longer sequences 2023-12-16 19:11:23 -08:00
Tri Dao
08124c8f9c [CrossEntropy] Implement logit_scale option 2023-12-16 18:39:37 -08:00
Tri Dao
9356a1c038 [LayerNorm] Implement layer_norm_linear 2023-11-30 21:46:07 -08:00
Tri Dao
aaa1474129 [CrossEntropy] Simplify the case of large vocab with Tensor Parallel 2023-11-19 23:19:36 -08:00
Shijie
abf04a56e1
fix flash ce mp large vocab (#673) 2023-11-19 23:01:07 -08:00
Tri Dao
017716451d [LayerNorm] Add postnorm residual + LayerNorm/RMSNorm in Triton 2023-11-13 22:37:55 -08:00
Tri Dao
79bd1a2d5d [LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton 2023-11-13 02:04:49 -08:00
Tri Dao
e279bf8ed9 [Gen] Accept cache_batch_idx to index into the KV cache 2023-10-03 16:27:26 -07:00
Tri Dao
083e8f525f Implement local attention
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
Tri Dao
65c234ed90 Don't over-allocate dq_accum in case of varlen 2023-09-24 00:36:07 -07:00
Tri Dao
2d8ea9a530 Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza) 2023-09-20 23:38:22 -07:00
Tri Dao
0705d2718d [Llama] Fix some tests, add tests for Llama 2 and CodeLlama 2023-09-20 23:36:46 -07:00
Tri Dao
e0fbaa7016 [Gen] Simplify decode_speculative 2023-09-19 22:20:22 -07:00
Tri Dao
e6a8026489 [Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset 2023-09-19 22:20:22 -07:00