Commit Graph

213 Commits

Author SHA1 Message Date
Tri Dao
6d673cd961 Bump to v2.2.0 2023-09-05 11:34:13 -07:00
Kyeongpil Kang
8e893f0950 Create __init__.py for ops/triton dir (#516) 2023-09-05 11:29:03 -07:00
Tri Dao
fd20f16a4e Support cache_seqlens being an integer 2023-09-05 11:27:48 -07:00
Tri Dao
913922cac5 [Gen] Refactor decoding function 2023-09-04 17:01:38 -07:00
Tri Dao
3557e0bb8f [MLP] Implement SwiGLU with torch jiterator 2023-09-04 15:43:53 -07:00
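
For reference, the elementwise math that the jiterator-based SwiGLU in the commit above fuses is SiLU(x) * gate. A plain-PyTorch sketch of that math only, not the repo's fused kernel:

```python
import torch
import torch.nn.functional as F

def swiglu_ref(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    # SwiGLU activation: SiLU(x) * gate. The commit fuses this elementwise
    # op via torch.cuda.jiterator; this reference just spells out the math.
    return F.silu(x) * gate
```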
Tri Dao
37c6e05406 Implement flash_attn_with_kvcache 2023-09-04 00:11:44 -07:00
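
The commit above introduces flash_attn_with_kvcache for incremental decoding. A minimal usage sketch follows; the shapes, dtypes, and keyword arguments shown are illustrative assumptions, not a full description of the API:

```python
import torch
from flash_attn import flash_attn_with_kvcache

# Illustrative shapes: decoding one new token per sequence.
batch, nheads, headdim, cache_len = 2, 16, 64, 512
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(batch, cache_len, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)
k_new = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
v_new = torch.randn_like(k_new)
# Current filled length of each sequence in the cache; per commit fd20f16a4e
# above, a plain Python int is also accepted.
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device="cuda")
# k_new/v_new are appended into the cache at position cache_seqlens.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens, causal=True
)
```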
Tri Dao
4976650f74 Set single threaded compilation for CUDA 12.2 so CI doesn't OOM 2023-09-03 23:42:55 -07:00
Tri Dao
6a89b2f121 Remove constexpr in launch template to fix CI compilation 2023-09-03 22:59:41 -07:00
Tri Dao
97ba7a62e9 Try switching back to Cutlass 3.2.0 2023-09-03 22:45:35 -07:00
Tri Dao
1dc1b6c8f2 Bump to v2.1.2 2023-09-03 22:23:05 -07:00
Tri Dao
798858f9f1 Fix test_baichuan 2023-09-03 21:01:37 -07:00
Tri Dao
7b33743a72 [Gen] Add back num_last_tokens in gpt.py 2023-09-03 20:44:40 -07:00
Tri Dao
b28ec236df [Rotary] Implement varlen rotary 2023-09-03 17:57:10 -07:00
Tri Dao
861c82577d [Rotary] Clean up rotary Triton implementation a bit 2023-09-03 16:41:17 -07:00
Tri Dao
1c523c1ce1 [Rotary] Speed up rotary kernel when interleaved=True 2023-09-03 16:24:37 -07:00
Tri Dao
de2949f37d [Rotary] Pass max_seqlen from mha.py to rotary during inference 2023-09-03 11:37:06 -07:00
Tri Dao
942fcbf046 [Rotary] Implement rotary in Triton 2023-09-03 02:51:58 -07:00
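
A reference sketch of the rotary math the Triton kernel in the commit above implements, for the non-interleaved layout; the function name and shapes are illustrative assumptions:

```python
import torch

def apply_rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Non-interleaved rotary embedding: rotate the two halves of the head dim,
    # (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin).
    # x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, headdim // 2)
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    cos = cos[:, None, :]  # broadcast over the heads dimension
    sin = sin[:, None, :]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
```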
dan_the_3rd
c9d4a816fa Support LLaMa2 and CodeLLaMa (#491) 2023-08-30 10:31:14 -07:00
    Co-authored-by: danthe3rd <danthe3rd>
dan_the_3rd
011ec323d6 Support MQA + MP for decoding (#490) 2023-08-30 10:29:54 -07:00
    Co-authored-by: danthe3rd <danthe3rd>
GAOXinyu
0cb595ad94 [bugfix] handle_x not defined when using checkpoint_lvl=2 (#502) 2023-08-29 23:46:10 -07:00
    When using checkpoint_lvl=2, all_gather_raw(x) is called without async_op=True,
    so there is no handle to wait on; skip the wait.
Tri Dao
8a326bbc9e [Gen] Minor fix to modify logits for top_p 2023-08-29 14:29:06 -07:00
Su Zhu
8f6f48d8a8 add unpad_input_for_concatenated_sequences (#499) 2023-08-29 02:23:56 -07:00
    * add unpad_input_for_concatenated_sequences
    * modify docstring
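
A hedged usage sketch of the helper added in #499. This assumes each row of attention_mask_in_length lists the lengths of the samples packed into that row (zero meaning no further sample) and that the return values mirror unpad_input; both are assumptions, not confirmed by the log:

```python
import torch
from flash_attn.bert_padding import unpad_input_for_concatenated_sequences

# Assumed convention: row 0 packs samples of length 3 and 2 into a length-6
# sequence (one trailing pad); row 1 packs a single sample of length 4.
hidden_states = torch.randn(2, 6, 128)
attention_mask_in_length = torch.tensor([[3, 2, 0, 0, 0, 0],
                                         [4, 0, 0, 0, 0, 0]])
hidden_unpad, indices, cu_seqlens, max_seqlen = unpad_input_for_concatenated_sequences(
    hidden_states, attention_mask_in_length
)
```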
Tri Dao
757058d4d3 Update Cutlass to v3.2.0 2023-08-27 23:47:28 -07:00
Tri Dao
9f42cb6e7a [Gen] Clone logits before returning when cg=True 2023-08-27 23:19:58 -07:00
Tri Dao
f8aea6ead0 [GPT] Generalize last_token_only arg to num_last_tokens 2023-08-26 20:47:53 -07:00
Tri Dao
7a3bd55f1a [Gen] Fix decode function not using top_p during iterative decoding 2023-08-26 15:14:41 -07:00
Tri Dao
847abe653c [Gen] Refactor decode function a bit 2023-08-26 14:47:25 -07:00
Tri Dao
a2974e850a Change causal for CrossAttention in mha.py to align to bottom right 2023-08-26 12:57:33 -07:00
Tri Dao
73bd3f3bbb Move pyproject.toml to flash-attn and tests dir to avoid PEP 517 2023-08-25 15:05:28 -07:00
Tri Dao
9e5e8bc91e Change causal mask to be aligned to bottom-right instead of top-left 2023-08-24 23:41:07 -07:00
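
To make the semantic change above concrete: with bottom-right alignment, when seqlen_q != seqlen_k the causal diagonal is anchored at the bottom-right of the attention matrix, so the last query row always sees all keys. A small illustrative sketch (names are hypothetical):

```python
import torch

def causal_mask(seqlen_q: int, seqlen_k: int, align_bottom_right: bool = True) -> torch.Tensor:
    # True means "query i may attend to key j". Bottom-right alignment keeps
    # j <= i + (seqlen_k - seqlen_q); top-left alignment keeps j <= i.
    i = torch.arange(seqlen_q)[:, None]
    j = torch.arange(seqlen_k)[None, :]
    shift = seqlen_k - seqlen_q if align_bottom_right else 0
    return j <= i + shift
```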
Aman Gupta Karmani
e0b09891c6 add llama support to GPTPreTrainedModel.from_pretrained (#479) 2023-08-24 16:31:16 -07:00
Tri Dao
6711b3bc40 Bump version to 2.0.9 2023-08-22 00:21:14 -07:00
Tri Dao
ef6d8c75d9 [GPT] Fix loading weights from HF hub 2023-08-21 22:56:02 -07:00
GAOXinyu
a8c35b4f57 FEAT: add code supporting baichuan-inc/Baichuan-7B (#425) 2023-08-21 11:05:06 -07:00
Xuechen Li
25d6b1dbcb handle uneven heads across ranks when combining state_dicts; resolves #467 (#468) 2023-08-20 14:57:34 -07:00
    * q
    * add comment.
Tri Dao
d431f16751 Import torch before flash_attn_2_cuda 2023-08-19 21:07:33 -07:00
Xuechen Li
7fcd3e6a04 map custom model state_dict back to huggingface format (#465) 2023-08-18 20:51:39 -07:00
    * fix name.
    * set inv function.
    * add map back function.
    * handle gqa.
    * add type annotation to avoid confusion.
    * fix docstr.
    * test inverse remap logic.
Tri Dao
f1a73d0740 Run isort and black on python files 2023-08-18 14:22:11 -07:00
Xuechen Li
bb4cded17b support when num_heads is not divisible by world_size; resolves #459 (#461) 2023-08-18 14:10:35 -07:00
    * unequal rank.
    * trim.
    * enable passing in number of heads for each rank.
    * simplify.
    * simplify.
    * cleanup.
    * fix col parallel.
    * fix bug with row parallel.
    * fit out proj.
    * refac.
    * fix sharding logic.
    * refac sharding.
    * refac.
    * support multiple of.
    * make fn reusable.
    * fix bug in dimensions.
    * scaffold.
    * test uneven heads.
    * fix test by adding barrier.
    * refac.
    * reuse code.
    * clean up.
Tri Dao
ada4710d70 [ViT] Run black on vit.py 2023-08-17 17:45:09 -07:00
Tri Dao
a81900d4c1 [ViT] Minor fix so it runs 2023-08-17 17:25:34 -07:00
Tri Dao
4b661a569d [GPT] Run black on gpt.py 2023-08-16 23:47:50 -07:00
Tri Dao
bec5b3d374 [MHA] Run black on mha.py 2023-08-16 23:47:13 -07:00
Tri Dao
cb0daccc41 [FusedDense] Allow Row/ColumnParallelLinear to have uneven split 2023-08-16 23:43:35 -07:00
Tri Dao
bcfa7c9751 [FusedDense] Run black on fused_dense.py 2023-08-16 23:41:36 -07:00
Tri Dao
c65b5106ac Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal 2023-08-16 15:12:36 -07:00
Xuechen Li
0f7853c6a1 enable loading hf llama checkpoints for training (#446) 2023-08-15 08:33:15 -07:00
    * prelim.
    * add hf conversion fn.
    * mlp.
    * change name.
    * fix bug.
    * inverse permute.
    * change comment.
    * revert style changes.
    * fix.
    * add doc.
    * revert.
    * enable load safe.
    * fix safe load.
    * fix import.
    * fix typing-related lints.
    * fix ckpt loading logic.
    * make single gpu work.
    * test with parallel.
    * ckpt format.
    * enable pretrained state dict.
    * remove unused imports.
    * remove unused.
    * mark idea related.
Tri Dao
c60851a825 Bump to v2.0.7 2023-08-14 14:55:35 -07:00
Tri Dao
f8dccfc90a [CI] Fix MATRIX_CUDA_VERSION check 2023-08-14 10:27:26 -07:00
Tri Dao
9c531bdc0a Use single-thread compilation for CUDA 12.1, torch 2.1 to avoid CI OOM 2023-08-14 10:03:31 -07:00