Commit Graph

60 Commits

Author SHA1 Message Date
Tri Dao
ccbb14f38e Implement rotary embedding in flash_attn_with_kvcache 2023-09-16 01:20:16 -07:00
Tri Dao
d0032700d1 Add tests for Pythia, GPT-JT, and RedPajama models 2023-09-13 01:10:39 -07:00
Kevin Hu
07005806ff
Add BigCode converters (#532) 2023-09-10 17:24:50 -07:00
Tri Dao
8a733cbd53 [Gen] Fix calling update_graph_cache in tests 2023-09-10 17:22:37 -07:00
Kevin Hu
4c91621a5e
Inverse state dict for BERT (#527) 2023-09-09 01:44:21 -07:00
Tri Dao
a86442f0f3 [Gen] Use flash_attn_with_kvcache in generation 2023-09-07 08:24:43 -07:00
Tri Dao
9795159082 [Rotary] Set device before launching Triton kernel to avoid error 2023-09-05 21:29:03 -07:00
Tri Dao
fd20f16a4e Support cache_seqlens being integer 2023-09-05 11:27:48 -07:00
Tri Dao
913922cac5 [Gen] Refactor decoding function 2023-09-04 17:01:38 -07:00
Tri Dao
798858f9f1 Fix test_baichuan 2023-09-03 21:01:37 -07:00
Tri Dao
942fcbf046 [Rotary] Implement rotary in Triton 2023-09-03 02:51:58 -07:00
dan_the_3rd
011ec323d6
Support MQA + MP for decoding (#490)
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:29:54 -07:00
Tri Dao
9f42cb6e7a [Gen] Clone logits before returning when cg=True 2023-08-27 23:19:58 -07:00
Tri Dao
f8aea6ead0 [GPT] Generalize last_token_only arg to num_last_tokens 2023-08-26 20:47:53 -07:00
Tri Dao
371e20658c [GPT] Test generation when passing in multiple tokens 2023-08-26 13:56:41 -07:00
Tri Dao
c000c3a2c0 [GPT] Move more tests to test_gpt.py 2023-08-26 13:00:40 -07:00
Tri Dao
9b713872ea [GPT] Move GPT and OPT generation tests to test_{gpt,opt}.py 2023-08-26 12:55:02 -07:00
Tri Dao
ef6d8c75d9 [GPT] Fix loading weights from HF hub 2023-08-21 22:56:02 -07:00
GAOXinyu
a8c35b4f57
FEAT: add codes which supporting for baichuan-inc/Baichuan-7B (#425) 2023-08-21 11:05:06 -07:00
Tri Dao
0e8c46ae08 Run isort and black on test files 2023-08-18 20:59:35 -07:00
Xuechen Li
7fcd3e6a04
map custom model state_dict back to huggingface format (#465)
* fix name.

* set inv function.

* add map back function.

* handle gqa.

* add type annotation to avoid confusion.

* fix docstr.

* test inverse remap logic.
2023-08-18 20:51:39 -07:00
Xuechen Li
bb4cded17b
support when num_heads is not divisible by world_size; resolves #459 (#461)
* uneql rank.

* trim.

* enable passing in number of heads for each rank.

* simplify.

* simplify.

* cleanup.

* fix col parallel.

* fix bug with row parallel.

* fit out proj.

* refac.

* fix sharding logic.

* refac sharding.

* refac.

* support multiple of.

* make fn reuseable.

* fix bug in dimensions.

* scaffold.

* test uneven heads.

* fix test by adding barrier.

* refac.

* reuse code.

* clean up.
2023-08-18 14:10:35 -07:00
Tri Dao
a81900d4c1 [ViT] Minor fix so it runs 2023-08-17 17:25:34 -07:00
Xuechen Li
0f7853c6a1
enable loading hf llama checkpoints for training (#446)
* prelim.

* add hf convertion fn.

* mlp.

* change name.

* fix bug.

* inverse permute.

* change comment.

* revert style changes.

* fix.

* add doc.

* revert.

* enable load safe.

* fix safe load.

* fix import.

* fix typing-related lints.

* fix ckpt loading logic.

* make single gpu work.

* test with parallel.

* ckpt format.

* enable pretrained state dict.

* remove unused imports.

* remove unused.

* mark idea related.
2023-08-15 08:33:15 -07:00
Tri Dao
184b992dcb [GPT] Implement parallel LLaMa 2023-07-28 15:52:48 -10:00
Tri Dao
56ccaff126 [GPT] Add LLaMa-13B to test 2023-07-26 07:22:22 -10:00
Tri Dao
8e9820a55b [Rotary] Fix tests when loading state dict with rotary inv_freqs 2023-07-26 07:16:33 -10:00
Tri Dao
d38357dd2f [GPT] Implement Falcon 2023-07-23 10:32:29 -07:00
Tri Dao
425dbcb6c6 [MHA] Implement MQA/GQA 2023-07-23 00:06:58 -07:00
Tri Dao
b3177dfaf6 [GPT] Enable FlashAttention for GPT-J 2023-07-21 17:29:10 -07:00
Tri Dao
62e9814466 [Rotary] Make sure frequency calculation is in fp32 2023-07-02 16:39:39 -07:00
Tri Dao
48bc6eacd6 [Gen] Add rotary base as an argument to FT attention kernel 2023-05-30 13:38:34 -07:00
Tri Dao
a9a4b4e4f2 [LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm 2023-05-04 23:39:43 -07:00
Tri Dao
311d6606bf [Gen] Fix FT kernel smem size, CG when batch size changed 2023-04-20 17:03:13 -07:00
Tri Dao
96d10f6545 Implement LLaMa 2023-04-18 21:51:35 -07:00
Tri Dao
605655bc66 [Gen] Fix FT kernel when using CG 2023-04-14 16:50:01 -07:00
Tri Dao
393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 2023-03-31 14:23:45 -07:00
Tri Dao
993d12448e Implement GPT-NeoX 2023-03-29 01:21:25 -07:00
Tri Dao
f5d0fbd468 [FT] Fix FT's single query attention for bf16 hdim128 rotary 2023-03-28 21:27:00 -07:00
Tri Dao
4d87e4d875 Implement GPT-J 2023-03-22 16:16:58 -07:00
Tri Dao
78b7a1dc18 [OPT] Load fp16 weights on CPU before moving to GPU 2023-01-22 17:01:32 -08:00
Tri Dao
f68d41ec77 [Gen] Add OPT to generation test 2023-01-17 19:59:06 -08:00
Tri Dao
88173a1aaf [FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP 2023-01-17 18:12:27 -08:00
Tri Dao
780e8eeabb [ViT] Support timm checkpoint, add tests 2023-01-16 01:20:34 -08:00
Tri Dao
ff34123bd4 Reorder LN in Block, support OPT 2023-01-15 22:14:31 -08:00
Tri Dao
f1e01c27ba [Gen] Pass qkv_stride to ft_attention kernel for batched generation 2023-01-15 15:20:01 -08:00
Tri Dao
7c2191542a [Gen] Make generation work with Tensor Parallel 2023-01-15 11:34:27 -08:00
Tri Dao
b48599002a [Gen] Add timing option 2023-01-07 19:05:09 -08:00
Tri Dao
0938298e4c [Gen] Adjust shape of kv_cache when using FT 2023-01-07 17:27:54 -08:00
Tri Dao
e02fd588aa [Gen] Implement top-k and top-p sampling 2023-01-07 17:00:02 -08:00