Commit Graph

329 Commits

Author SHA1 Message Date
Tri Dao
abbc131173 [LayerNorm] Switch from CUDA to Triton implementation 2024-01-05 00:31:17 -08:00
Tri Dao
f5b308e258 [LayerNorm] Rename layernorm.py -> layer_norm.py 2024-01-05 00:21:03 -08:00
Tri Dao
665b55e2e2 [LayerNorm] Implement parallel layer norm in Triton 2024-01-04 23:15:35 -08:00
Tri Dao
aa5c6438c5 [LayerNorm] Implement rowscale in Triton layernorm 2024-01-04 01:07:03 -08:00
jiaxingli
386e391117
Fix: implement deterministic backward in mha (#748)
* fix deterministic

* fix deterministic
2024-01-02 18:13:56 -08:00
Tri Dao
1a2c3e8c25 Bump to v2.4.2 2023-12-25 16:28:57 -08:00
Tri Dao
73df3be7d5 Add test for BTLM init 2023-12-25 15:16:27 -08:00
Tri Dao
7ffba9a501 Implement BTLM model 2023-12-24 20:35:12 -08:00
Tri Dao
2e29dacf0c Implement muParam 2023-12-24 20:34:48 -08:00
Tri Dao
3f7d5786ba Pass alibi slopes to flash_attn_with_kvcache during generation 2023-12-24 20:31:59 -08:00
Tri Dao
f844852485 Bump to v2.4.1 2023-12-23 21:00:39 -08:00
Tri Dao
732654583c Implement deterministic backward (thanks to Meituan) 2023-12-23 17:57:36 -08:00
Tri Dao
2c7d7b7396 Implement norm head for Baichuan2 2023-12-22 16:55:40 -08:00
Tri Dao
68f178aa4b [CI] Don't compile for python 3.7 pytorch 2.2 2023-12-22 10:10:02 -08:00
Tri Dao
7316277303 Bump to v2.4.0 2023-12-22 00:09:53 -08:00
Tri Dao
c3b2196652 Add Alibi to MHA, test with Baichuan-13B 2023-12-21 22:49:55 -08:00
Tri Dao
5ab9b3667b Clean up alibi, implement non-causal alibi 2023-12-21 22:27:40 -08:00
Tri Dao
bc28eacc60 Format flash_attn_interface.py 2023-12-19 23:13:53 -08:00
Tri Dao
0a146185d6 [Gen] Remove minor dead code 2023-12-19 22:57:39 -08:00
Sanghun Cho
e4f726fc44
Support alibi, by Sanghun Cho from Kakao Brain
* hard-code alibi in fwd

* use params.h as num_heads

* hard-code alibi in bwd

* add alibi on/off option

* compute alibi_start, ratio outside of kernels

* fix minor merge conflict

* add test_alibi.py

* change apply_alibi() location before masking

* add alibi in splitkv kernel

* fix backward func # of returns

* add out-of-bound check in apply_alibi()

* update test_alibi.py

* update test_alibi.py for kvcache

* simplify alibi parameter interface

* fix performance issue by computing alibi outside of branch

* update test_flash_attn_varlen_func() for left padding

* implement alibi_slopes (b, nh) loading

* optimize apply_alibi() a bit

* update test cases for alibi_slopes loading

* reflect stylistic comments

* disable "seqlenq_ngroups_swapped" when using alibi

---------

Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
2023-12-19 22:56:06 -08:00
Tri Dao
cd089597fd [LayerNorm] Implement dropout in fused residual + LN/RMSNorm 2023-12-19 16:26:07 -08:00
Tri Dao
08124c8f9c [CrossEntropy] Implement logit_scale option 2023-12-16 18:39:37 -08:00
Tri Dao
9356a1c038 [LayerNorm] Implement layer_norm_linear 2023-11-30 21:46:07 -08:00
Tri Dao
92dd5703ec Bump to v2.3.6 2023-11-27 16:23:39 -08:00
Tri Dao
d4a7c8ffbb [CI] Only compile for CUDA 11.8 & 12.2, MAX_JOBS=2, add torch-nightly 2023-11-27 16:21:28 -08:00
Jeremy Reizenstein
ce3e7280f8
Allow varlen_fwd to take optional seqused_k (#647)
Co-authored-by: bottler <bottler@users.noreply.github.com>
2023-11-27 00:41:23 -08:00
Tri Dao
23b77c8148 Bump to v2.3.5 2023-11-26 19:08:28 -08:00
Tri Dao
2c3baba4a6 Bump to v2.3.4 2023-11-19 23:21:31 -08:00
Tri Dao
aaa1474129 [CrossEntropy] Simplify the case of large vocab with Tensor Parallel 2023-11-19 23:19:36 -08:00
Shijie
abf04a56e1
fix flash ce mp large vocab (#673) 2023-11-19 23:01:07 -08:00
Tri Dao
017716451d [LayerNorm] Add postnorm residual + LayerNorm/RMSNorm in Triton 2023-11-13 22:37:55 -08:00
Tri Dao
79bd1a2d5d [LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton 2023-11-13 02:04:49 -08:00
Antony Frolov
3566596ad8
Fix typo in RotaryEmbedding forward output type (#666) 2023-11-09 11:43:02 -08:00
Tri Dao
83aef842be Bump to v2.3.3 2023-10-24 00:24:07 -07:00
Tri Dao
c79de85ffa [CrossEntropy] Fix triton cross_entropy_loss IMA for >=2B elements 2023-10-24 00:17:34 -07:00
Tri Dao
7f31e7c16a Bump to v2.3.2 2023-10-08 17:21:29 -07:00
Tri Dao
5e525a8dc8 [CI] Use official Pytorch 2.1, add CUDA 11.8 for Pytorch 2.1 2023-10-03 22:20:30 -07:00
Tri Dao
21c3b0d8f6 Bump to v2.3.1 2023-10-03 19:56:45 -07:00
Tri Dao
e279bf8ed9 [Gen] Accept cache_batch_idx to index into the KV cache 2023-10-03 16:27:26 -07:00
Tri Dao
601b4dc48d Bump to v2.3.0 2023-09-26 22:08:29 -07:00
Tri Dao
083e8f525f Implement local attention
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
Katherine Crowson
4c8ff9154e
Fix NameError and typo in ApplyRotaryEmbQKV_ (#569) 2023-09-25 10:47:34 -07:00
Tri Dao
0a1d03c7ea Bump to v2.2.5 2023-09-24 00:54:03 -07:00
Tri Dao
1879e089c7 Reduce number of templates for headdim > 128 2023-09-23 22:24:30 -07:00
Tri Dao
bff3147175 Re-enable compilation for Hopper 2023-09-21 23:55:25 -07:00
Yuchao Dai
187c2a0635
Fix E1136 (#563) 2023-09-21 11:48:23 -07:00
Tri Dao
229080b9d2 Bump to v2.2.4 2023-09-20 23:39:38 -07:00
Tri Dao
0705d2718d [Llama] Fix some tests, add tests for Llama 2 and CodeLlama 2023-09-20 23:36:46 -07:00
Tri Dao
e0fbaa7016 [Gen] Simplify decode_speculative 2023-09-19 22:20:22 -07:00
Tri Dao
e6a8026489 [Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset 2023-09-19 22:20:22 -07:00
Kevin Hu
42832575d4
Fix Llama GQA/MQA (#546)
* Fix llama MQA

* Fix permute shape

* Update llama.py
2023-09-19 22:15:59 -07:00
Tri Dao
dfe29f5e2b [Gen] Don't use ft_attention, use flash_attn_with_kvcache instead 2023-09-18 15:29:06 -07:00
Tri Dao
799f56fa90 Don't compile for Pytorch 2.1 on CUDA 12.1 due to nvcc segfaults 2023-09-17 22:15:38 -07:00
Tri Dao
c984208ddb Set block size to 64 x 64 for kvcache to avoid nvcc segfaults 2023-09-17 16:14:58 -07:00
Tri Dao
8c8b4d36e1 Bump to v2.2.3 2023-09-16 01:47:01 -07:00
Tri Dao
ccbb14f38e Implement rotary embedding in flash_attn_with_kvcache 2023-09-16 01:20:16 -07:00
Tri Dao
5400fdc4ac [CE] Implement CrossEntropyLoss in Triton 2023-09-15 20:05:28 -07:00
Tri Dao
d0032700d1 Add tests for Pythia, GPT-JT, and RedPajama models 2023-09-13 01:10:39 -07:00
Tri Dao
08c295c043 Bump to v2.2.2 2023-09-10 23:48:12 -07:00
Tri Dao
ee77b931b9 Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza) 2023-09-10 22:56:33 -07:00
Kevin Hu
07005806ff
Add BigCode converters (#532) 2023-09-10 17:24:50 -07:00
Tri Dao
8a733cbd53 [Gen] Fix calling update_graph_cache in tests 2023-09-10 17:22:37 -07:00
Kevin Hu
4c91621a5e
Inverse state dict for BERT (#527) 2023-09-09 01:44:21 -07:00
Tri Dao
a86442f0f3 [Gen] Use flash_attn_with_kvcache in generation 2023-09-07 08:24:43 -07:00
Tri Dao
a1576ad1e8 Bump to v2.2.1 2023-09-06 02:19:55 -07:00
Tri Dao
9795159082 [Rotary] Set device before launching Triton kernel to avoid error 2023-09-05 21:29:03 -07:00
Tri Dao
6d673cd961 Bump to v2.2.0 2023-09-05 11:34:13 -07:00
Kyeongpil Kang
8e893f0950
Create __init__.py for ops/triton dir (#516) 2023-09-05 11:29:03 -07:00
Tri Dao
fd20f16a4e Support cache_seqlens being integer 2023-09-05 11:27:48 -07:00
Tri Dao
913922cac5 [Gen] Refactor decoding function 2023-09-04 17:01:38 -07:00
Tri Dao
3557e0bb8f [MLP] Implement SwiGLU with torch jiterator 2023-09-04 15:43:53 -07:00
Tri Dao
37c6e05406 Implement flash_attn_with_kvcache 2023-09-04 00:11:44 -07:00
Tri Dao
4976650f74 Set single threaded compilation for CUDA 12.2 so CI doesn't OOM 2023-09-03 23:42:55 -07:00
Tri Dao
6a89b2f121 Remove constexpr in launch template to fix CI compilation 2023-09-03 22:59:41 -07:00
Tri Dao
97ba7a62e9 Try switching back to Cutlass 3.2.0 2023-09-03 22:45:35 -07:00
Tri Dao
1dc1b6c8f2 Bump to v2.1.2 2023-09-03 22:23:05 -07:00
Tri Dao
798858f9f1 Fix test_baichuan 2023-09-03 21:01:37 -07:00
Tri Dao
7b33743a72 [Gen] Add back num_last_tokens in gpt.py 2023-09-03 20:44:40 -07:00
Tri Dao
b28ec236df [Rotary] Implement varlen rotary 2023-09-03 17:57:10 -07:00
Tri Dao
861c82577d [Rotary] Clean up rotary Triton implementation a bit 2023-09-03 16:41:17 -07:00
Tri Dao
1c523c1ce1 [Rotary] Speed up rotary kernel when interleaved=True 2023-09-03 16:24:37 -07:00
Tri Dao
de2949f37d [Rotary] Pass max_seqlen from mha.py to rotary during inference 2023-09-03 11:37:06 -07:00
Tri Dao
942fcbf046 [Rotary] Implement rotary in Triton 2023-09-03 02:51:58 -07:00
dan_the_3rd
c9d4a816fa
Support LLaMa2 and CodeLLaMa (#491)
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:31:14 -07:00
dan_the_3rd
011ec323d6
Support MQA + MP for decoding (#490)
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:29:54 -07:00
GAOXinyu
0cb595ad94
[bugfix] handle_x not defined when using checkpoint_lvl = 2 (#502)
When using checkpoint_lvl=2, we call all_gather_raw(x) without async_op=True,
so there is no handle to wait for; just skip the wait.
2023-08-29 23:46:10 -07:00
Tri Dao
8a326bbc9e [Gen] Minor fix to modify logits for top_p 2023-08-29 14:29:06 -07:00
Su Zhu
8f6f48d8a8
add unpad_input_for_concatenated_sequences (#499)
* add unpad_input_for_concatenated_sequences

* modify docstring
2023-08-29 02:23:56 -07:00
Tri Dao
757058d4d3 Update Cutlass to v3.2.0 2023-08-27 23:47:28 -07:00
Tri Dao
9f42cb6e7a [Gen] Clone logits before returning when cg=True 2023-08-27 23:19:58 -07:00
Tri Dao
f8aea6ead0 [GPT] Generalize last_token_only arg to num_last_tokens 2023-08-26 20:47:53 -07:00
Tri Dao
7a3bd55f1a [Gen] Fix decode function not using top_p during iterative decoding 2023-08-26 15:14:41 -07:00
Tri Dao
847abe653c [Gen] Refactor decode function a bit 2023-08-26 14:47:25 -07:00
Tri Dao
a2974e850a Change causal for CrossAttention in mha.py to align to bottom right 2023-08-26 12:57:33 -07:00
Tri Dao
73bd3f3bbb Move pyproject.toml to flash-attn and tests dir to avoid PEP 517 2023-08-25 15:05:28 -07:00
Tri Dao
9e5e8bc91e Change causal mask to be aligned to bottom-right instead of top-left 2023-08-24 23:41:07 -07:00
Aman Gupta Karmani
e0b09891c6
add llama support to GPTPreTrainedModel.from_pretrained (#479) 2023-08-24 16:31:16 -07:00
Tri Dao
6711b3bc40 Bump version to 2.0.9 2023-08-22 00:21:14 -07:00
Tri Dao
ef6d8c75d9 [GPT] Fix loading weights from HF hub 2023-08-21 22:56:02 -07:00
GAOXinyu
a8c35b4f57
FEAT: add code supporting baichuan-inc/Baichuan-7B (#425) 2023-08-21 11:05:06 -07:00