Commit Graph

43 Commits

Author SHA1 Message Date
Tri Dao
320fb59487 Update citation 2024-05-26 16:09:03 -07:00
Tri Dao
0a146185d6 [Gen] Remove minor dead code 2023-12-19 22:57:39 -08:00
Tri Dao
79bd1a2d5d [LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton 2023-11-13 02:04:49 -08:00
Tri Dao
e0fbaa7016 [Gen] Simplify decode_speculative 2023-09-19 22:20:22 -07:00
Tri Dao
e6a8026489 [Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset 2023-09-19 22:20:22 -07:00
Tri Dao
dfe29f5e2b [Gen] Don't use ft_attention, use flash_attn_with_kvcache instead 2023-09-18 15:29:06 -07:00
Tri Dao
a86442f0f3 [Gen] Use flash_attn_with_kvcache in generation 2023-09-07 08:24:43 -07:00
Tri Dao
913922cac5 [Gen] Refactor decoding function 2023-09-04 17:01:38 -07:00
Tri Dao
37c6e05406 Implement flash_attn_with_kvcache 2023-09-04 00:11:44 -07:00
Tri Dao
8a326bbc9e [Gen] Minor fix to modify logits for top_p 2023-08-29 14:29:06 -07:00
Tri Dao
9f42cb6e7a [Gen] Clone logits before returning when cg=True 2023-08-27 23:19:58 -07:00
Tri Dao
f8aea6ead0 [GPT] Generalize last_token_only arg to num_last_tokens 2023-08-26 20:47:53 -07:00
Tri Dao
7a3bd55f1a [Gen] Fix decode function not using top_p during iterative decoding 2023-08-26 15:14:41 -07:00
Tri Dao
847abe653c [Gen] Refactor decode function a bit 2023-08-26 14:47:25 -07:00
Tri Dao
ef6d8c75d9 [GPT] Fix loading weights from HF hub 2023-08-21 22:56:02 -07:00
Tri Dao
f1a73d0740 Run isort and black on python files 2023-08-18 14:22:11 -07:00
Xuechen Li
bb4cded17b
support when num_heads is not divisible by world_size; resolves #459 (#461)
* uneql rank.

* trim.

* enable passing in number of heads for each rank.

* simplify.

* simplify.

* cleanup.

* fix col parallel.

* fix bug with row parallel.

* fit out proj.

* refac.

* fix sharding logic.

* refac sharding.

* refac.

* support multiple of.

* make fn reuseable.

* fix bug in dimensions.

* scaffold.

* test uneven heads.

* fix test by adding barrier.

* refac.

* reuse code.

* clean up.
2023-08-18 14:10:35 -07:00
Xuechen Li
0f7853c6a1
enable loading hf llama checkpoints for training (#446)
* prelim.

* add hf convertion fn.

* mlp.

* change name.

* fix bug.

* inverse permute.

* change comment.

* revert style changes.

* fix.

* add doc.

* revert.

* enable load safe.

* fix safe load.

* fix import.

* fix typing-related lints.

* fix ckpt loading logic.

* make single gpu work.

* test with parallel.

* ckpt format.

* enable pretrained state dict.

* remove unused imports.

* remove unused.

* mark idea related.
2023-08-15 08:33:15 -07:00
Tri Dao
60499abcfd [Benchmark] Add script to benchmark FlashAttention 2023-07-28 00:26:52 -10:00
Tri Dao
fcab93b43a [Gen] Minor tweak to allocate_inference_cache 2023-04-21 11:56:47 -07:00
Tri Dao
ba2fe7f378 [Gen] Move allocate_inference_cache to within the model 2023-04-20 18:15:12 -07:00
Tri Dao
3da42d24b1 [GPT] Add option to only return the logit for the last token 2023-04-20 17:21:08 -07:00
Tri Dao
311d6606bf [Gen] Fix FT kernel smem size, CG when batch size changed 2023-04-20 17:03:13 -07:00
Tri Dao
605655bc66 [Gen] Fix FT kernel when using CG 2023-04-14 16:50:01 -07:00
Tri Dao
1c9ef9b399 [Gen] Measure prompt processing + decoding time, not just decoding 2023-04-13 15:39:56 -07:00
Tri Dao
f5d0fbd468 [FT] Fix FT's single query attention for bf16 hdim128 rotary 2023-03-28 21:27:00 -07:00
Tri Dao
4d87e4d875 Implement GPT-J 2023-03-22 16:16:58 -07:00
Tri Dao
78b7a1dc18 [OPT] Load fp16 weights on CPU before moving to GPU 2023-01-22 17:01:32 -08:00
Tri Dao
f68d41ec77 [Gen] Add OPT to generation test 2023-01-17 19:59:06 -08:00
Tri Dao
7c2191542a [Gen] Make generation work with Tensor Parallel 2023-01-15 11:34:27 -08:00
Tri Dao
f95c2fc108 [Gen] Remove commented code 2023-01-07 19:06:39 -08:00
Tri Dao
b48599002a [Gen] Add timing option 2023-01-07 19:05:09 -08:00
Tri Dao
e02fd588aa [Gen] Implement top-k and top-p sampling 2023-01-07 17:00:02 -08:00
Tri Dao
11be742aa3 [Gen] Test generation with rotary embedding 2023-01-07 14:37:54 -08:00
Tri Dao
93383bd55b [TP] Implement TensorParallel without sequence parallel 2023-01-07 13:45:22 -08:00
Tri Dao
a668890fcd [Gen] Add option to run generation with FT attention kernel 2023-01-03 22:10:31 -08:00
Tri Dao
a6ec1782dc Bump to v0.2.6 2022-12-27 22:05:20 -08:00
Tri Dao
63670fd84a Implement generation for GPT 2022-12-27 21:01:50 -08:00
Tri Dao
c6ecd40a59 Tweak CrossEntropyLoss to take process_group in init 2022-12-27 10:47:43 -08:00
Tri Dao
b4018a5028 Implement Tensor Parallel for GPT model 2022-12-26 16:22:43 -08:00
Tri Dao
226a1b721d Implement TensorParallel for FusedDense and FusedDenseGeluDense 2022-12-24 11:48:56 -08:00
Tri Dao
ece539abd6 Add __init__.py files to subdirectories for installation 2022-11-17 16:55:44 -08:00
Tri Dao
fb88e5e4b3 Move benchmark utils, support AMP 2022-10-23 12:50:00 -07:00