Tri Dao
8a733cbd53
[Gen] Fix calling update_graph_cache in tests
2023-09-10 17:22:37 -07:00
Kevin Hu
4c91621a5e
Inverse state dict for BERT ( #527 )
2023-09-09 01:44:21 -07:00
Tri Dao
a86442f0f3
[Gen] Use flash_attn_with_kvcache in generation
2023-09-07 08:24:43 -07:00
Tri Dao
a1576ad1e8
Bump to v2.2.1
2023-09-06 02:19:55 -07:00
Tri Dao
9795159082
[Rotary] Set device before launching Triton kernel to avoid error
2023-09-05 21:29:03 -07:00
Tri Dao
6d673cd961
Bump to v2.2.0
2023-09-05 11:34:13 -07:00
Kyeongpil Kang
8e893f0950
Create __init__.py for ops/triton dir ( #516 )
2023-09-05 11:29:03 -07:00
Tri Dao
fd20f16a4e
Support cache_seqlens being integer
2023-09-05 11:27:48 -07:00
Tri Dao
913922cac5
[Gen] Refactor decoding function
2023-09-04 17:01:38 -07:00
Tri Dao
3557e0bb8f
[MLP] Implement SwiGLU with torch jiterator
2023-09-04 15:43:53 -07:00
Tri Dao
37c6e05406
Implement flash_attn_with_kvcache
2023-09-04 00:11:44 -07:00
Tri Dao
4976650f74
Set single threaded compilation for CUDA 12.2 so CI doesn't OOM
2023-09-03 23:42:55 -07:00
Tri Dao
6a89b2f121
Remove constexpr in launch template to fix CI compilation
2023-09-03 22:59:41 -07:00
Tri Dao
97ba7a62e9
Try switching back to Cutlass 3.2.0
2023-09-03 22:45:35 -07:00
Tri Dao
1dc1b6c8f2
Bump to v2.1.2
2023-09-03 22:23:05 -07:00
Tri Dao
0c04943fa2
Require CUDA 11.6+, clean up setup.py
2023-09-03 21:24:56 -07:00
Tri Dao
798858f9f1
Fix test_baichuan
2023-09-03 21:01:37 -07:00
Tri Dao
7b33743a72
[Gen] Add back num_last_tokens in gpt.py
2023-09-03 20:44:40 -07:00
Tri Dao
5953c4f58c
Remove unused sdPsum in dot_do_o function
2023-09-03 20:44:07 -07:00
Tri Dao
b28ec236df
[Rotary] Implement varlen rotary
2023-09-03 17:57:10 -07:00
Tri Dao
861c82577d
[Rotary] Clean up rotary Triton implementation a bit
2023-09-03 16:41:17 -07:00
Tri Dao
1c523c1ce1
[Rotary] Speed up rotary kernel when interleaved=True
2023-09-03 16:24:37 -07:00
Tri Dao
26d7d92f3d
Fix splitKV combine function when local LSEs are all -inf
2023-09-03 11:39:09 -07:00
Tri Dao
de2949f37d
[Rotary] Pass max_seqlen from mha.py to rotary during inference
2023-09-03 11:37:06 -07:00
Tri Dao
942fcbf046
[Rotary] Implement rotary in Triton
2023-09-03 02:51:58 -07:00
Tri Dao
08e9847176
[CI] Add CUDA 12.2
2023-09-03 02:45:42 -07:00
Sophia Wisdom
37e32febba
Remove commented out code in bwd ( #512 )
...
* Remove lots of comments
* Remove unused traits
2023-09-01 16:43:58 -07:00
Sophia Wisdom
dd8a754915
Remove old code in utils.h ( #511 )
2023-09-01 15:32:09 -07:00
Aman Gupta Karmani
866a9d33f9
bump cutlass submodule ( #504 )
2023-08-30 10:32:04 -07:00
dan_the_3rd
c9d4a816fa
Support LLaMa2 and CodeLLaMa ( #491 )
...
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:31:14 -07:00
dan_the_3rd
011ec323d6
Support MQA + MP for decoding ( #490 )
...
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:29:54 -07:00
GAOXinyu
0cb595ad94
[bugfix] handle_x not define when using checkpoint_lvl = 2 ( #502 )
...
when using checkpoint_lvl=2, we all_gather_raw(x) without async_op=True.
So we don't need to wait for handle. Just skip.
2023-08-29 23:46:10 -07:00
Tri Dao
31920dda5f
Fix typo with lse_max == -INFINITY
2023-08-29 21:48:59 -07:00
Tri Dao
8a326bbc9e
[Gen] Minor fix to modify logits for top_p
2023-08-29 14:29:06 -07:00
Jeffrey Quesnelle
1d817a8ffc
fix citation in README ( #501 )
2023-08-29 11:15:33 -07:00
Su Zhu
8f6f48d8a8
add unpad_input_for_concatenated_sequences ( #499 )
...
* add unpad_input_for_concatenated_sequences
* modify docstring
2023-08-29 02:23:56 -07:00
Tri Dao
b1fbbd8337
Implement splitKV attention
2023-08-29 00:58:29 -07:00
Tri Dao
7a983df742
Use generate_kernels.py script from Driss Guessous
2023-08-28 13:34:12 -07:00
dan_the_3rd
c3f2a632aa
[ft_attention] Fix for seqlen=8136 ( #488 )
...
When seqlen=8136, `smem_sz = 48840`, and apparently starting the kernel returns an `invalid argument` CUDA error.
`48840 < 48 * 1024` but apparently it's still above the limit somehow..?
Tested on A100
2023-08-28 10:00:22 -07:00
Tri Dao
757058d4d3
Update Cutlass to v3.2.0
2023-08-27 23:47:28 -07:00
Tri Dao
9f42cb6e7a
[Gen] Clone logits before returning when cg=True
2023-08-27 23:19:58 -07:00
Tri Dao
f8aea6ead0
[GPT] Generalize last_token_only arg to num_last_tokens
2023-08-26 20:47:53 -07:00
Tri Dao
7a3bd55f1a
[Gen] Fix decode function not using top_p during iterative decoding
2023-08-26 15:14:41 -07:00
Tri Dao
847abe653c
[Gen] Refactor decode function a bit
2023-08-26 14:47:25 -07:00
Tri Dao
371e20658c
[GPT] Test generation when passing in multiple tokens
2023-08-26 13:56:41 -07:00
Tri Dao
c000c3a2c0
[GPT] Move more tests to test_gpt.py
2023-08-26 13:00:40 -07:00
Tri Dao
a2974e850a
Change causal for CrossAttention in mha.py to align to bottom right
2023-08-26 12:57:33 -07:00
Tri Dao
9b713872ea
[GPT] Move GPT and OPT generation tests to test_{gpt,opt}.py
2023-08-26 12:55:02 -07:00
Tri Dao
73bd3f3bbb
Move pyproject.toml to flash-attn and tests dir to avoid PEP 517
2023-08-25 15:05:28 -07:00
Aman Gupta Karmani
b4b6e90334
add benchmark for xformers fa2 wrapper ( #492 )
2023-08-25 14:10:05 -07:00