| Author | Commit | Message | Date |
| Tri Dao | 6d673cd961 | Bump to v2.2.0 | 2023-09-05 11:34:13 -07:00 |
| Kyeongpil Kang | 8e893f0950 | Create __init__.py for ops/triton dir (#516) | 2023-09-05 11:29:03 -07:00 |
| Tri Dao | fd20f16a4e | Support cache_seqlens being integer | 2023-09-05 11:27:48 -07:00 |
| Tri Dao | 913922cac5 | [Gen] Refactor decoding function | 2023-09-04 17:01:38 -07:00 |
| Tri Dao | 3557e0bb8f | [MLP] Implement SwiGLU with torch jiterator | 2023-09-04 15:43:53 -07:00 |
| Tri Dao | 37c6e05406 | Implement flash_attn_with_kvcache | 2023-09-04 00:11:44 -07:00 |
| Tri Dao | 4976650f74 | Set single threaded compilation for CUDA 12.2 so CI doesn't OOM | 2023-09-03 23:42:55 -07:00 |
| Tri Dao | 6a89b2f121 | Remove constexpr in launch template to fix CI compilation | 2023-09-03 22:59:41 -07:00 |
| Tri Dao | 97ba7a62e9 | Try switching back to Cutlass 3.2.0 | 2023-09-03 22:45:35 -07:00 |
| Tri Dao | 1dc1b6c8f2 | Bump to v2.1.2 | 2023-09-03 22:23:05 -07:00 |
| Tri Dao | 798858f9f1 | Fix test_baichuan | 2023-09-03 21:01:37 -07:00 |
| Tri Dao | 7b33743a72 | [Gen] Add back num_last_tokens in gpt.py | 2023-09-03 20:44:40 -07:00 |
| Tri Dao | b28ec236df | [Rotary] Implement varlen rotary | 2023-09-03 17:57:10 -07:00 |
| Tri Dao | 861c82577d | [Rotary] Clean up rotary Triton implementation a bit | 2023-09-03 16:41:17 -07:00 |
| Tri Dao | 1c523c1ce1 | [Rotary] Speed up rotary kernel when interleaved=True | 2023-09-03 16:24:37 -07:00 |
| Tri Dao | de2949f37d | [Rotary] Pass max_seqlen from mha.py to rotary during inference | 2023-09-03 11:37:06 -07:00 |
| Tri Dao | 942fcbf046 | [Rotary] Implement rotary in Triton | 2023-09-03 02:51:58 -07:00 |
| dan_the_3rd | c9d4a816fa | Support LLaMa2 and CodeLLaMa (#491); Co-authored-by: danthe3rd <danthe3rd> | 2023-08-30 10:31:14 -07:00 |
| dan_the_3rd | 011ec323d6 | Support MQA + MP for decoding (#490); Co-authored-by: danthe3rd <danthe3rd> | 2023-08-30 10:29:54 -07:00 |
| GAOXinyu | 0cb595ad94 | [bugfix] handle_x not define when using checkpoint_lvl = 2 (#502); when using checkpoint_lvl=2, we all_gather_raw(x) without async_op=True. So we don't need to wait for handle. Just skip. | 2023-08-29 23:46:10 -07:00 |
| Tri Dao | 8a326bbc9e | [Gen] Minor fix to modify logits for top_p | 2023-08-29 14:29:06 -07:00 |
| Su Zhu | 8f6f48d8a8 | add unpad_input_for_concatenated_sequences (#499); add unpad_input_for_concatenated_sequences; modify docstring | 2023-08-29 02:23:56 -07:00 |
| Tri Dao | 757058d4d3 | Update Cutlass to v3.2.0 | 2023-08-27 23:47:28 -07:00 |
| Tri Dao | 9f42cb6e7a | [Gen] Clone logits before returning when cg=True | 2023-08-27 23:19:58 -07:00 |
| Tri Dao | f8aea6ead0 | [GPT] Generalize last_token_only arg to num_last_tokens | 2023-08-26 20:47:53 -07:00 |
| Tri Dao | 7a3bd55f1a | [Gen] Fix decode function not using top_p during iterative decoding | 2023-08-26 15:14:41 -07:00 |
| Tri Dao | 847abe653c | [Gen] Refactor decode function a bit | 2023-08-26 14:47:25 -07:00 |
| Tri Dao | a2974e850a | Change causal for CrossAttention in mha.py to align to bottom right | 2023-08-26 12:57:33 -07:00 |
| Tri Dao | 73bd3f3bbb | Move pyproject.toml to flash-attn and tests dir to avoid PEP 517 | 2023-08-25 15:05:28 -07:00 |
| Tri Dao | 9e5e8bc91e | Change causal mask to be aligned to bottom-right instead of top-left | 2023-08-24 23:41:07 -07:00 |
| Aman Gupta Karmani | e0b09891c6 | add llama support to GPTPreTrainedModel.from_pretrained (#479) | 2023-08-24 16:31:16 -07:00 |
| Tri Dao | 6711b3bc40 | Bump version to 2.0.9 | 2023-08-22 00:21:14 -07:00 |
| Tri Dao | ef6d8c75d9 | [GPT] Fix loading weights from HF hub | 2023-08-21 22:56:02 -07:00 |
| GAOXinyu | a8c35b4f57 | FEAT: add codes which supporting for baichuan-inc/Baichuan-7B (#425) | 2023-08-21 11:05:06 -07:00 |
| Xuechen Li | 25d6b1dbcb | handle uneven heads across ranks when combining state_dicts; resolves #467 (#468); q; add comment | 2023-08-20 14:57:34 -07:00 |
| Tri Dao | d431f16751 | Import torch before flash_attn_2_cuda | 2023-08-19 21:07:33 -07:00 |
| Xuechen Li | 7fcd3e6a04 | map custom model state_dict back to huggingface format (#465); fix name; set inv function; add map back function; handle gqa; add type annotation to avoid confusion; fix docstr; test inverse remap logic | 2023-08-18 20:51:39 -07:00 |
| Tri Dao | f1a73d0740 | Run isort and black on python files | 2023-08-18 14:22:11 -07:00 |
| Xuechen Li | bb4cded17b | support when num_heads is not divisible by world_size; resolves #459 (#461); uneql rank; trim; enable passing in number of heads for each rank; simplify; simplify; cleanup; fix col parallel; fix bug with row parallel; fit out proj; refac; fix sharding logic; refac sharding; refac; support multiple of; make fn reuseable; fix bug in dimensions; scaffold; test uneven heads; fix test by adding barrier; refac; reuse code; clean up | 2023-08-18 14:10:35 -07:00 |
| Tri Dao | ada4710d70 | [ViT] Run black on vit.py | 2023-08-17 17:45:09 -07:00 |
| Tri Dao | a81900d4c1 | [ViT] Minor fix so it runs | 2023-08-17 17:25:34 -07:00 |
| Tri Dao | 4b661a569d | [GPT] Run black on gpt.py | 2023-08-16 23:47:50 -07:00 |
| Tri Dao | bec5b3d374 | [MHA] Run black on mha.py | 2023-08-16 23:47:13 -07:00 |
| Tri Dao | cb0daccc41 | [FusedDense] Allow Row/ColumnParallelLinear to have uneven split | 2023-08-16 23:43:35 -07:00 |
| Tri Dao | bcfa7c9751 | [FusedDense] Run black on fused_dense.py | 2023-08-16 23:41:36 -07:00 |
| Tri Dao | c65b5106ac | Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal | 2023-08-16 15:12:36 -07:00 |
| Xuechen Li | 0f7853c6a1 | enable loading hf llama checkpoints for training (#446); prelim; add hf convertion fn; mlp; change name; fix bug; inverse permute; change comment; revert style changes; fix; add doc; revert; enable load safe; fix safe load; fix import; fix typing-related lints; fix ckpt loading logic; make single gpu work; test with parallel; ckpt format; enable pretrained state dict; remove unused imports; remove unused; mark idea related | 2023-08-15 08:33:15 -07:00 |
| Tri Dao | c60851a825 | Bump to v2.0.7 | 2023-08-14 14:55:35 -07:00 |
| Tri Dao | f8dccfc90a | [CI] Fix MATRIX_CUDA_VERSION check | 2023-08-14 10:27:26 -07:00 |
| Tri Dao | 9c531bdc0a | Use single thread compilation for cuda12.1, torch2.1 to avoid OOM CI | 2023-08-14 10:03:31 -07:00 |