| Author | Commit | Message | Date |
|---|---|---|---|
| Tri Dao | 0a1d03c7ea | Bump to v2.2.5 | 2023-09-24 00:54:03 -07:00 |
| Tri Dao | 1879e089c7 | Reduce number of templates for headdim > 128 | 2023-09-23 22:24:30 -07:00 |
| Tri Dao | bff3147175 | Re-enable compilation for Hopper | 2023-09-21 23:55:25 -07:00 |
| Yuchao Dai | 187c2a0635 | Fix E1136 (#563) | 2023-09-21 11:48:23 -07:00 |
| Tri Dao | 229080b9d2 | Bump to v2.2.4 | 2023-09-20 23:39:38 -07:00 |
| Tri Dao | 0705d2718d | [Llama] Fix some tests, add tests for Llama 2 and CodeLlama | 2023-09-20 23:36:46 -07:00 |
| Tri Dao | e0fbaa7016 | [Gen] Simplify decode_speculative | 2023-09-19 22:20:22 -07:00 |
| Tri Dao | e6a8026489 | [Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset | 2023-09-19 22:20:22 -07:00 |
| Kevin Hu | 42832575d4 | Fix Llama GQA/MQA (#546) * Fix llama MQA * Fix permute shape * Update llama.py | 2023-09-19 22:15:59 -07:00 |
| Tri Dao | dfe29f5e2b | [Gen] Don't use ft_attention, use flash_attn_with_kvcache instead | 2023-09-18 15:29:06 -07:00 |
| Tri Dao | 799f56fa90 | Don't compile for Pytorch 2.1 on CUDA 12.1 due to nvcc segfaults | 2023-09-17 22:15:38 -07:00 |
| Tri Dao | c984208ddb | Set block size to 64 x 64 for kvcache to avoid nvcc segfaults | 2023-09-17 16:14:58 -07:00 |
| Tri Dao | 8c8b4d36e1 | Bump to v2.2.3 | 2023-09-16 01:47:01 -07:00 |
| Tri Dao | ccbb14f38e | Implement rotary embedding in flash_attn_with_kvcache | 2023-09-16 01:20:16 -07:00 |
| Tri Dao | 5400fdc4ac | [CE] Implement CrossEntropyLoss in Triton | 2023-09-15 20:05:28 -07:00 |
| Tri Dao | d0032700d1 | Add tests for Pythia, GPT-JT, and RedPajama models | 2023-09-13 01:10:39 -07:00 |
| Tri Dao | 08c295c043 | Bump to v2.2.2 | 2023-09-10 23:48:12 -07:00 |
| Tri Dao | ee77b931b9 | Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza) | 2023-09-10 22:56:33 -07:00 |
| Kevin Hu | 07005806ff | Add BigCode converters (#532) | 2023-09-10 17:24:50 -07:00 |
| Tri Dao | 8a733cbd53 | [Gen] Fix calling update_graph_cache in tests | 2023-09-10 17:22:37 -07:00 |
| Kevin Hu | 4c91621a5e | Inverse state dict for BERT (#527) | 2023-09-09 01:44:21 -07:00 |
| Tri Dao | a86442f0f3 | [Gen] Use flash_attn_with_kvcache in generation | 2023-09-07 08:24:43 -07:00 |
| Tri Dao | a1576ad1e8 | Bump to v2.2.1 | 2023-09-06 02:19:55 -07:00 |
| Tri Dao | 9795159082 | [Rotary] Set device before launching Triton kernel to avoid error | 2023-09-05 21:29:03 -07:00 |
| Tri Dao | 6d673cd961 | Bump to v2.2.0 | 2023-09-05 11:34:13 -07:00 |
| Kyeongpil Kang | 8e893f0950 | Create __init__.py for ops/triton dir (#516) | 2023-09-05 11:29:03 -07:00 |
| Tri Dao | fd20f16a4e | Support cache_seqlens being integer | 2023-09-05 11:27:48 -07:00 |
| Tri Dao | 913922cac5 | [Gen] Refactor decoding function | 2023-09-04 17:01:38 -07:00 |
| Tri Dao | 3557e0bb8f | [MLP] Implement SwiGLU with torch jiterator | 2023-09-04 15:43:53 -07:00 |
| Tri Dao | 37c6e05406 | Implement flash_attn_with_kvcache | 2023-09-04 00:11:44 -07:00 |
| Tri Dao | 4976650f74 | Set single threaded compilation for CUDA 12.2 so CI doesn't OOM | 2023-09-03 23:42:55 -07:00 |
| Tri Dao | 6a89b2f121 | Remove constexpr in launch template to fix CI compilation | 2023-09-03 22:59:41 -07:00 |
| Tri Dao | 97ba7a62e9 | Try switching back to Cutlass 3.2.0 | 2023-09-03 22:45:35 -07:00 |
| Tri Dao | 1dc1b6c8f2 | Bump to v2.1.2 | 2023-09-03 22:23:05 -07:00 |
| Tri Dao | 798858f9f1 | Fix test_baichuan | 2023-09-03 21:01:37 -07:00 |
| Tri Dao | 7b33743a72 | [Gen] Add back num_last_tokens in gpt.py | 2023-09-03 20:44:40 -07:00 |
| Tri Dao | b28ec236df | [Rotary] Implement varlen rotary | 2023-09-03 17:57:10 -07:00 |
| Tri Dao | 861c82577d | [Rotary] Clean up rotary Triton implementation a bit | 2023-09-03 16:41:17 -07:00 |
| Tri Dao | 1c523c1ce1 | [Rotary] Speed up rotary kernel when interleaved=True | 2023-09-03 16:24:37 -07:00 |
| Tri Dao | de2949f37d | [Rotary] Pass max_seqlen from mha.py to rotary during inference | 2023-09-03 11:37:06 -07:00 |
| Tri Dao | 942fcbf046 | [Rotary] Implement rotary in Triton | 2023-09-03 02:51:58 -07:00 |
| dan_the_3rd | c9d4a816fa | Support LLaMa2 and CodeLLaMa (#491) Co-authored-by: danthe3rd <danthe3rd> | 2023-08-30 10:31:14 -07:00 |
| dan_the_3rd | 011ec323d6 | Support MQA + MP for decoding (#490) Co-authored-by: danthe3rd <danthe3rd> | 2023-08-30 10:29:54 -07:00 |
| GAOXinyu | 0cb595ad94 | [bugfix] handle_x not define when using checkpoint_lvl = 2 (#502) when using checkpoint_lvl=2, we all_gather_raw(x) without async_op=True. So we don't need to wait for handle. Just skip. | 2023-08-29 23:46:10 -07:00 |
| Tri Dao | 8a326bbc9e | [Gen] Minor fix to modify logits for top_p | 2023-08-29 14:29:06 -07:00 |
| Su Zhu | 8f6f48d8a8 | add unpad_input_for_concatenated_sequences (#499) * add unpad_input_for_concatenated_sequences * modify docstring | 2023-08-29 02:23:56 -07:00 |
| Tri Dao | 757058d4d3 | Update Cutlass to v3.2.0 | 2023-08-27 23:47:28 -07:00 |
| Tri Dao | 9f42cb6e7a | [Gen] Clone logits before returning when cg=True | 2023-08-27 23:19:58 -07:00 |
| Tri Dao | f8aea6ead0 | [GPT] Generalize last_token_only arg to num_last_tokens | 2023-08-26 20:47:53 -07:00 |
| Tri Dao | 7a3bd55f1a | [Gen] Fix decode function not using top_p during iterative decoding | 2023-08-26 15:14:41 -07:00 |