Author | Commit | Message | Date
------ | ------ | ------- | ----
Tri Dao | 320fb59487 | Update citation | 2024-05-26 16:09:03 -07:00
Tri Dao | 0a146185d6 | [Gen] Remove minor dead code | 2023-12-19 22:57:39 -08:00
Tri Dao | 79bd1a2d5d | [LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton | 2023-11-13 02:04:49 -08:00
Tri Dao | e0fbaa7016 | [Gen] Simplify decode_speculative | 2023-09-19 22:20:22 -07:00
Tri Dao | e6a8026489 | [Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset | 2023-09-19 22:20:22 -07:00
Tri Dao | dfe29f5e2b | [Gen] Don't use ft_attention, use flash_attn_with_kvcache instead | 2023-09-18 15:29:06 -07:00
Tri Dao | a86442f0f3 | [Gen] Use flash_attn_with_kvcache in generation | 2023-09-07 08:24:43 -07:00
Tri Dao | 913922cac5 | [Gen] Refactor decoding function | 2023-09-04 17:01:38 -07:00
Tri Dao | 37c6e05406 | Implement flash_attn_with_kvcache | 2023-09-04 00:11:44 -07:00
Tri Dao | 8a326bbc9e | [Gen] Minor fix to modify logits for top_p | 2023-08-29 14:29:06 -07:00
Tri Dao | 9f42cb6e7a | [Gen] Clone logits before returning when cg=True | 2023-08-27 23:19:58 -07:00
Tri Dao | f8aea6ead0 | [GPT] Generalize last_token_only arg to num_last_tokens | 2023-08-26 20:47:53 -07:00
Tri Dao | 7a3bd55f1a | [Gen] Fix decode function not using top_p during iterative decoding | 2023-08-26 15:14:41 -07:00
Tri Dao | 847abe653c | [Gen] Refactor decode function a bit | 2023-08-26 14:47:25 -07:00
Tri Dao | ef6d8c75d9 | [GPT] Fix loading weights from HF hub | 2023-08-21 22:56:02 -07:00
Tri Dao | f1a73d0740 | Run isort and black on python files | 2023-08-18 14:22:11 -07:00
Xuechen Li | bb4cded17b | support when num_heads is not divisible by world_size; resolves #459 (#461) | 2023-08-18 14:10:35 -07:00
  * uneql rank.
  * trim.
  * enable passing in number of heads for each rank.
  * simplify.
  * simplify.
  * cleanup.
  * fix col parallel.
  * fix bug with row parallel.
  * fit out proj.
  * refac.
  * fix sharding logic.
  * refac sharding.
  * refac.
  * support multiple of.
  * make fn reuseable.
  * fix bug in dimensions.
  * scaffold.
  * test uneven heads.
  * fix test by adding barrier.
  * refac.
  * reuse code.
  * clean up.
Xuechen Li | 0f7853c6a1 | enable loading hf llama checkpoints for training (#446) | 2023-08-15 08:33:15 -07:00
  * prelim.
  * add hf convertion fn.
  * mlp.
  * change name.
  * fix bug.
  * inverse permute.
  * change comment.
  * revert style changes.
  * fix.
  * add doc.
  * revert.
  * enable load safe.
  * fix safe load.
  * fix import.
  * fix typing-related lints.
  * fix ckpt loading logic.
  * make single gpu work.
  * test with parallel.
  * ckpt format.
  * enable pretrained state dict.
  * remove unused imports.
  * remove unused.
  * mark idea related.
Tri Dao | 60499abcfd | [Benchmark] Add script to benchmark FlashAttention | 2023-07-28 00:26:52 -10:00
Tri Dao | fcab93b43a | [Gen] Minor tweak to allocate_inference_cache | 2023-04-21 11:56:47 -07:00
Tri Dao | ba2fe7f378 | [Gen] Move allocate_inference_cache to within the model | 2023-04-20 18:15:12 -07:00
Tri Dao | 3da42d24b1 | [GPT] Add option to only return the logit for the last token | 2023-04-20 17:21:08 -07:00
Tri Dao | 311d6606bf | [Gen] Fix FT kernel smem size, CG when batch size changed | 2023-04-20 17:03:13 -07:00
Tri Dao | 605655bc66 | [Gen] Fix FT kernel when using CG | 2023-04-14 16:50:01 -07:00
Tri Dao | 1c9ef9b399 | [Gen] Measure prompt processing + decoding time, not just decoding | 2023-04-13 15:39:56 -07:00
Tri Dao | f5d0fbd468 | [FT] Fix FT's single query attention for bf16 hdim128 rotary | 2023-03-28 21:27:00 -07:00
Tri Dao | 4d87e4d875 | Implement GPT-J | 2023-03-22 16:16:58 -07:00
Tri Dao | 78b7a1dc18 | [OPT] Load fp16 weights on CPU before moving to GPU | 2023-01-22 17:01:32 -08:00
Tri Dao | f68d41ec77 | [Gen] Add OPT to generation test | 2023-01-17 19:59:06 -08:00
Tri Dao | 7c2191542a | [Gen] Make generation work with Tensor Parallel | 2023-01-15 11:34:27 -08:00
Tri Dao | f95c2fc108 | [Gen] Remove commented code | 2023-01-07 19:06:39 -08:00
Tri Dao | b48599002a | [Gen] Add timing option | 2023-01-07 19:05:09 -08:00
Tri Dao | e02fd588aa | [Gen] Implement top-k and top-p sampling | 2023-01-07 17:00:02 -08:00
Tri Dao | 11be742aa3 | [Gen] Test generation with rotary embedding | 2023-01-07 14:37:54 -08:00
Tri Dao | 93383bd55b | [TP] Implement TensorParallel without sequence parallel | 2023-01-07 13:45:22 -08:00
Tri Dao | a668890fcd | [Gen] Add option to run generation with FT attention kernel | 2023-01-03 22:10:31 -08:00
Tri Dao | a6ec1782dc | Bump to v0.2.6 | 2022-12-27 22:05:20 -08:00
Tri Dao | 63670fd84a | Implement generation for GPT | 2022-12-27 21:01:50 -08:00
Tri Dao | c6ecd40a59 | Tweak CrossEntropyLoss to take process_group in init | 2022-12-27 10:47:43 -08:00
Tri Dao | b4018a5028 | Implement Tensor Parallel for GPT model | 2022-12-26 16:22:43 -08:00
Tri Dao | 226a1b721d | Implement TensorParallel for FusedDense and FusedDenseGeluDense | 2022-12-24 11:48:56 -08:00
Tri Dao | ece539abd6 | Add __init__.py files to subdirectories for installation | 2022-11-17 16:55:44 -08:00
Tri Dao | fb88e5e4b3 | Move benchmark utils, support AMP | 2022-10-23 12:50:00 -07:00