Tri Dao
dc72d960a7
[CI] Install torch 2.3 using index
2024-01-30 14:32:29 -08:00
Tri Dao
daf37a9d8a
Bump to v2.5.1
2024-01-29 21:03:38 -08:00
Avelina9X
c94cd09744
Updated missing docstrings for args and returns in bert_padding.py ( #795 )
...
* Updated docstrings of bert_padding.py
Added docstrings for missing arguments in the unpad and pad methods.
* Update bert_padding.py
Fixed spelling mistakes
2024-01-27 09:16:25 -08:00
Tao He
204c3c6d1b
Fixes an error in comment ( #785 )
...
Signed-off-by: Tao He <sighingnow@gmail.com>
2024-01-23 12:38:29 -08:00
Tri Dao
197f2083a2
Bump to v2.5.0
2024-01-22 23:40:10 -08:00
Tri Dao
54e80a3829
Implement page KV cache
...
Co-authored-by: ljss <450993438@qq.com>
2024-01-22 22:47:30 -08:00
Tri Dao
bdcae547c7
[LayerNorm] Don't exit early in the backward pass ( fix #781 )
2024-01-22 22:40:06 -08:00
Tri Dao
e43a4ceaab
[CI] Fix CUDA 12.2.2 compilation
2024-01-21 17:23:39 -08:00
Tri Dao
f9d7376126
Bump to v2.4.3
2024-01-21 17:14:37 -08:00
Curtis "Fjord" Hawthorne
d8aacc510c
return z_loss ( #768 )
2024-01-21 15:23:41 -08:00
Tri Dao
a7b66ae25a
Simplify writing softmax to gmem
2024-01-13 00:25:04 -08:00
Tri Dao
c9861a032d
[LayerNorm] Initialize mean and rstd tensor using x.device
2024-01-09 16:30:31 -08:00
Tri Dao
abbc131173
[LayerNorm] Switch from CUDA to Triton implementation
2024-01-05 00:31:17 -08:00
Tri Dao
f5b308e258
[LayerNorm] Rename layernorm.py -> layer_norm.py
2024-01-05 00:21:03 -08:00
Tri Dao
665b55e2e2
[LayerNorm] Implement parallel layer norm in Triton
2024-01-04 23:15:35 -08:00
Tri Dao
aa5c6438c5
[LayerNorm] Implement rowscale in Triton layernorm
2024-01-04 01:07:03 -08:00
jiaxingli
386e391117
Fix: implement deterministic backward in mha ( #748 )
...
* fix deterministic
* fix deterministic
2024-01-02 18:13:56 -08:00
Tri Dao
1a2c3e8c25
Bump to v2.4.2
2023-12-25 16:28:57 -08:00
Tri Dao
73df3be7d5
Add test for BTLM init
2023-12-25 15:16:27 -08:00
Tri Dao
7ffba9a501
Implement BTLM model
2023-12-24 20:35:12 -08:00
Tri Dao
2e29dacf0c
Implement muParam
2023-12-24 20:34:48 -08:00
Tri Dao
3f7d5786ba
Pass alibi slopes to flash_attn_with_kvcache during generation
2023-12-24 20:31:59 -08:00
Tri Dao
f844852485
Bump to v2.4.1
2023-12-23 21:00:39 -08:00
Tri Dao
732654583c
Implement deterministic backward (thanks to Meituan)
2023-12-23 17:57:36 -08:00
Tri Dao
2c7d7b7396
Implement norm head for Baichuan2
2023-12-22 16:55:40 -08:00
Tri Dao
68f178aa4b
[CI] Don't compile for python 3.7 pytorch 2.2
2023-12-22 10:10:02 -08:00
Tri Dao
7316277303
Bump to v2.4.0
2023-12-22 00:09:53 -08:00
Tri Dao
c3b2196652
Add Alibi to MHA, test with Baichuan-13B
2023-12-21 22:49:55 -08:00
Tri Dao
5ab9b3667b
Clean up alibi, implement non-causal alibi
2023-12-21 22:27:40 -08:00
Tri Dao
bc28eacc60
Format flash_attn_interface.py
2023-12-19 23:13:53 -08:00
Tri Dao
0a146185d6
[Gen] Remove minor dead code
2023-12-19 22:57:39 -08:00
Sanghun Cho
e4f726fc44
Support alibi, by Sanghun Cho from Kakao Brain
...
* hard-code alibi in fwd
* use params.h as hun_heads
* hard-code alibi in bwd
* add alibi on/off option
* compute alibi_start, ratio outside of kernels
* fix minor merge conflict
* add test_alibi.py
* change apply_alibi() location before masking
* add alibi in splitkv kernel
* fix backward func # of returns
* add out-of-bound check in apply_alibi()
* update test_alibi.py
* update test_alibi.py for kvcache
* simplify alibi parameter interface
* fix performance issue
by computing alibi outside of branch
* update test_flash_attn_varlen_func() for left padding
* implement alibi_slopes (b, nh) loading
* optimize apply_alibi() a bit
* update test cases for alibi_slopes loading
* reflect stylistic comments
* disable "seqlenq_ngroups_swapped" when using alibi
---------
Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
2023-12-19 22:56:06 -08:00
Tri Dao
cd089597fd
[LayerNorm] Implement dropout in fused residual + LN/RMSNorm
2023-12-19 16:26:07 -08:00
Tri Dao
08124c8f9c
[CrossEntropy] Implement logit_scale option
2023-12-16 18:39:37 -08:00
Tri Dao
9356a1c038
[LayerNorm] Implement layer_norm_linear
2023-11-30 21:46:07 -08:00
Tri Dao
92dd5703ec
Bump to v2.3.6
2023-11-27 16:23:39 -08:00
Tri Dao
d4a7c8ffbb
[CI] Only compile for CUDA 11.8 & 12.2, MAX_JOBS=2,add torch-nightly
2023-11-27 16:21:28 -08:00
Jeremy Reizenstein
ce3e7280f8
Allow varlen_fwd to take optional seqused_k ( #647 )
...
Co-authored-by: bottler <bottler@users.noreply.github.com>
2023-11-27 00:41:23 -08:00
Tri Dao
23b77c8148
Bump to v2.3.5
2023-11-26 19:08:28 -08:00
Tri Dao
2c3baba4a6
Bump to v2.3.4
2023-11-19 23:21:31 -08:00
Tri Dao
aaa1474129
[CrossEntropy] Simplify the case of large vocab with Tensor Parallel
2023-11-19 23:19:36 -08:00
Shijie
abf04a56e1
fix flash ce mp large vocab ( #673 )
2023-11-19 23:01:07 -08:00
Tri Dao
017716451d
[LayerNorm] Add postnorm residual + LayerNorm/RMSNorm in Triton
2023-11-13 22:37:55 -08:00
Tri Dao
79bd1a2d5d
[LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton
2023-11-13 02:04:49 -08:00
Antony Frolov
3566596ad8
Fix typo in RotaryEmbedding forward output type ( #666 )
2023-11-09 11:43:02 -08:00
Tri Dao
83aef842be
Bump to v2.3.3
2023-10-24 00:24:07 -07:00
Tri Dao
c79de85ffa
[CrossEntropy] Fix triton cross_entropy_loss IMA for >=2B elements
2023-10-24 00:17:34 -07:00
Tri Dao
7f31e7c16a
Bump to v2.3.2
2023-10-08 17:21:29 -07:00
Tri Dao
5e525a8dc8
[CI] Use official Pytorch 2.1, add CUDA 11.8 for Pytorch 2.1
2023-10-03 22:20:30 -07:00
Tri Dao
21c3b0d8f6
Bump to v2.3.1
2023-10-03 19:56:45 -07:00
Tri Dao
e279bf8ed9
[Gen] Accept cache_batch_idx to index into the KV cache
2023-10-03 16:27:26 -07:00
Tri Dao
601b4dc48d
Bump to v2.3.0
2023-09-26 22:08:29 -07:00
Tri Dao
083e8f525f
Implement local attention
...
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
Katherine Crowson
4c8ff9154e
Fix NameError and typo in ApplyRotaryEmbQKV_ ( #569 )
2023-09-25 10:47:34 -07:00
Tri Dao
0a1d03c7ea
Bump to v2.2.5
2023-09-24 00:54:03 -07:00
Tri Dao
1879e089c7
Reduce number of templates for headdim > 128
2023-09-23 22:24:30 -07:00
Tri Dao
bff3147175
Re-enable compilation for Hopper
2023-09-21 23:55:25 -07:00
Yuchao Dai
187c2a0635
Fix E1136 ( #563 )
2023-09-21 11:48:23 -07:00
Tri Dao
229080b9d2
Bump to v2.2.4
2023-09-20 23:39:38 -07:00
Tri Dao
0705d2718d
[Llama] Fix some tests, add tests for Llama 2 and CodeLlama
2023-09-20 23:36:46 -07:00
Tri Dao
e0fbaa7016
[Gen] Simplify decode_speculative
2023-09-19 22:20:22 -07:00
Tri Dao
e6a8026489
[Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset
2023-09-19 22:20:22 -07:00
Kevin Hu
42832575d4
Fix Llama GQA/MQA ( #546 )
...
* Fix llama MQA
* Fix permute shape
* Update llama.py
2023-09-19 22:15:59 -07:00
Tri Dao
dfe29f5e2b
[Gen] Don't use ft_attention, use flash_attn_with_kvcache instead
2023-09-18 15:29:06 -07:00
Tri Dao
799f56fa90
Don't compile for Pytorch 2.1 on CUDA 12.1 due to nvcc segfaults
2023-09-17 22:15:38 -07:00
Tri Dao
c984208ddb
Set block size to 64 x 64 for kvcache to avoid nvcc segfaults
2023-09-17 16:14:58 -07:00
Tri Dao
8c8b4d36e1
Bump to v2.2.3
2023-09-16 01:47:01 -07:00
Tri Dao
ccbb14f38e
Implement rotary embedding in flash_attn_with_kvcache
2023-09-16 01:20:16 -07:00
Tri Dao
5400fdc4ac
[CE] Implement CrossEntropyLoss in Triton
2023-09-15 20:05:28 -07:00
Tri Dao
d0032700d1
Add tests for Pythia, GPT-JT, and RedPajama models
2023-09-13 01:10:39 -07:00
Tri Dao
08c295c043
Bump to v2.2.2
2023-09-10 23:48:12 -07:00
Tri Dao
ee77b931b9
Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza)
2023-09-10 22:56:33 -07:00
Kevin Hu
07005806ff
Add BigCode converters ( #532 )
2023-09-10 17:24:50 -07:00
Tri Dao
8a733cbd53
[Gen] Fix calling update_graph_cache in tests
2023-09-10 17:22:37 -07:00
Kevin Hu
4c91621a5e
Inverse state dict for BERT ( #527 )
2023-09-09 01:44:21 -07:00
Tri Dao
a86442f0f3
[Gen] Use flash_attn_with_kvcache in generation
2023-09-07 08:24:43 -07:00
Tri Dao
a1576ad1e8
Bump to v2.2.1
2023-09-06 02:19:55 -07:00
Tri Dao
9795159082
[Rotary] Set device before launching Triton kernel to avoid error
2023-09-05 21:29:03 -07:00
Tri Dao
6d673cd961
Bump to v2.2.0
2023-09-05 11:34:13 -07:00
Kyeongpil Kang
8e893f0950
Create __init__.py for ops/triton dir ( #516 )
2023-09-05 11:29:03 -07:00
Tri Dao
fd20f16a4e
Support cache_seqlens being integer
2023-09-05 11:27:48 -07:00
Tri Dao
913922cac5
[Gen] Refactor decoding function
2023-09-04 17:01:38 -07:00
Tri Dao
3557e0bb8f
[MLP] Implement SwiGLU with torch jiterator
2023-09-04 15:43:53 -07:00
Tri Dao
37c6e05406
Implement flash_attn_with_kvcache
2023-09-04 00:11:44 -07:00
Tri Dao
4976650f74
Set single threaded compilation for CUDA 12.2 so CI doesn't OOM
2023-09-03 23:42:55 -07:00
Tri Dao
6a89b2f121
Remove constexpr in launch template to fix CI compilation
2023-09-03 22:59:41 -07:00
Tri Dao
97ba7a62e9
Try switching back to Cutlass 3.2.0
2023-09-03 22:45:35 -07:00
Tri Dao
1dc1b6c8f2
Bump to v2.1.2
2023-09-03 22:23:05 -07:00
Tri Dao
798858f9f1
Fix test_baichuan
2023-09-03 21:01:37 -07:00
Tri Dao
7b33743a72
[Gen] Add back num_last_tokens in gpt.py
2023-09-03 20:44:40 -07:00
Tri Dao
b28ec236df
[Rotary] Implement varlen rotary
2023-09-03 17:57:10 -07:00
Tri Dao
861c82577d
[Rotary] Clean up rotary Triton implementation a bit
2023-09-03 16:41:17 -07:00
Tri Dao
1c523c1ce1
[Rotary] Speed up rotary kernel when interleaved=True
2023-09-03 16:24:37 -07:00
Tri Dao
de2949f37d
[Rotary] Pass max_seqlen from mha.py to rotary during inference
2023-09-03 11:37:06 -07:00
Tri Dao
942fcbf046
[Rotary] Implement rotary in Triton
2023-09-03 02:51:58 -07:00
dan_the_3rd
c9d4a816fa
Support LLaMa2 and CodeLLaMa ( #491 )
...
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:31:14 -07:00
dan_the_3rd
011ec323d6
Support MQA + MP for decoding ( #490 )
...
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:29:54 -07:00
GAOXinyu
0cb595ad94
[bugfix] handle_x not define when using checkpoint_lvl = 2 ( #502 )
...
when using checkpoint_lvl=2, we all_gather_raw(x) without async_op=True.
So we don't need to wait for handle. Just skip.
2023-08-29 23:46:10 -07:00
Tri Dao
8a326bbc9e
[Gen] Minor fix to modify logits for top_p
2023-08-29 14:29:06 -07:00
Su Zhu
8f6f48d8a8
add unpad_input_for_concatenated_sequences ( #499 )
...
* add unpad_input_for_concatenated_sequences
* modify docstring
2023-08-29 02:23:56 -07:00