Tri Dao
|
ccbb14f38e
|
Implement rotary embedding in flash_attn_with_kvcache
|
2023-09-16 01:20:16 -07:00 |
|
Tri Dao
|
d0032700d1
|
Add tests for Pythia, GPT-JT, and RedPajama models
|
2023-09-13 01:10:39 -07:00 |
|
Kevin Hu
|
07005806ff
|
Add BigCode converters (#532)
|
2023-09-10 17:24:50 -07:00 |
|
Tri Dao
|
8a733cbd53
|
[Gen] Fix calling update_graph_cache in tests
|
2023-09-10 17:22:37 -07:00 |
|
Kevin Hu
|
4c91621a5e
|
Inverse state dict for BERT (#527)
|
2023-09-09 01:44:21 -07:00 |
|
Tri Dao
|
a86442f0f3
|
[Gen] Use flash_attn_with_kvcache in generation
|
2023-09-07 08:24:43 -07:00 |
|
Tri Dao
|
9795159082
|
[Rotary] Set device before launching Triton kernel to avoid error
|
2023-09-05 21:29:03 -07:00 |
|
Tri Dao
|
fd20f16a4e
|
Support cache_seqlens being integer
|
2023-09-05 11:27:48 -07:00 |
|
Tri Dao
|
913922cac5
|
[Gen] Refactor decoding function
|
2023-09-04 17:01:38 -07:00 |
|
Tri Dao
|
798858f9f1
|
Fix test_baichuan
|
2023-09-03 21:01:37 -07:00 |
|
Tri Dao
|
942fcbf046
|
[Rotary] Implement rotary in Triton
|
2023-09-03 02:51:58 -07:00 |
|
dan_the_3rd
|
011ec323d6
|
Support MQA + MP for decoding (#490)
Co-authored-by: danthe3rd <danthe3rd>
|
2023-08-30 10:29:54 -07:00 |
|
Tri Dao
|
9f42cb6e7a
|
[Gen] Clone logits before returning when cg=True
|
2023-08-27 23:19:58 -07:00 |
|
Tri Dao
|
f8aea6ead0
|
[GPT] Generalize last_token_only arg to num_last_tokens
|
2023-08-26 20:47:53 -07:00 |
|
Tri Dao
|
371e20658c
|
[GPT] Test generation when passing in multiple tokens
|
2023-08-26 13:56:41 -07:00 |
|
Tri Dao
|
c000c3a2c0
|
[GPT] Move more tests to test_gpt.py
|
2023-08-26 13:00:40 -07:00 |
|
Tri Dao
|
9b713872ea
|
[GPT] Move GPT and OPT generation tests to test_{gpt,opt}.py
|
2023-08-26 12:55:02 -07:00 |
|
Tri Dao
|
ef6d8c75d9
|
[GPT] Fix loading weights from HF hub
|
2023-08-21 22:56:02 -07:00 |
|
GAOXinyu
|
a8c35b4f57
|
FEAT: add codes which supporting for baichuan-inc/Baichuan-7B (#425)
|
2023-08-21 11:05:06 -07:00 |
|
Tri Dao
|
0e8c46ae08
|
Run isort and black on test files
|
2023-08-18 20:59:35 -07:00 |
|
Xuechen Li
|
7fcd3e6a04
|
map custom model state_dict back to huggingface format (#465)
* fix name.
* set inv function.
* add map back function.
* handle gqa.
* add type annotation to avoid confusion.
* fix docstr.
* test inverse remap logic.
|
2023-08-18 20:51:39 -07:00 |
|
Xuechen Li
|
bb4cded17b
|
support when num_heads is not divisible by world_size; resolves #459 (#461)
* unequal rank.
* trim.
* enable passing in number of heads for each rank.
* simplify.
* simplify.
* cleanup.
* fix col parallel.
* fix bug with row parallel.
* fit out proj.
* refac.
* fix sharding logic.
* refac sharding.
* refac.
* support multiple of.
* make fn reuseable.
* fix bug in dimensions.
* scaffold.
* test uneven heads.
* fix test by adding barrier.
* refac.
* reuse code.
* clean up.
|
2023-08-18 14:10:35 -07:00 |
|
Tri Dao
|
a81900d4c1
|
[ViT] Minor fix so it runs
|
2023-08-17 17:25:34 -07:00 |
|
Xuechen Li
|
0f7853c6a1
|
enable loading hf llama checkpoints for training (#446)
* prelim.
* add hf conversion fn.
* mlp.
* change name.
* fix bug.
* inverse permute.
* change comment.
* revert style changes.
* fix.
* add doc.
* revert.
* enable load safe.
* fix safe load.
* fix import.
* fix typing-related lints.
* fix ckpt loading logic.
* make single gpu work.
* test with parallel.
* ckpt format.
* enable pretrained state dict.
* remove unused imports.
* remove unused.
* mark idea related.
|
2023-08-15 08:33:15 -07:00 |
|
Tri Dao
|
184b992dcb
|
[GPT] Implement parallel LLaMa
|
2023-07-28 15:52:48 -10:00 |
|
Tri Dao
|
56ccaff126
|
[GPT] Add LLaMa-13B to test
|
2023-07-26 07:22:22 -10:00 |
|
Tri Dao
|
8e9820a55b
|
[Rotary] Fix tests when loading state dict with rotary inv_freqs
|
2023-07-26 07:16:33 -10:00 |
|
Tri Dao
|
d38357dd2f
|
[GPT] Implement Falcon
|
2023-07-23 10:32:29 -07:00 |
|
Tri Dao
|
425dbcb6c6
|
[MHA] Implement MQA/GQA
|
2023-07-23 00:06:58 -07:00 |
|
Tri Dao
|
b3177dfaf6
|
[GPT] Enable FlashAttention for GPT-J
|
2023-07-21 17:29:10 -07:00 |
|
Tri Dao
|
62e9814466
|
[Rotary] Make sure frequency calculation is in fp32
|
2023-07-02 16:39:39 -07:00 |
|
Tri Dao
|
48bc6eacd6
|
[Gen] Add rotary base as an argument to FT attention kernel
|
2023-05-30 13:38:34 -07:00 |
|
Tri Dao
|
a9a4b4e4f2
|
[LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm
|
2023-05-04 23:39:43 -07:00 |
|
Tri Dao
|
311d6606bf
|
[Gen] Fix FT kernel smem size, CG when batch size changed
|
2023-04-20 17:03:13 -07:00 |
|
Tri Dao
|
96d10f6545
|
Implement LLaMa
|
2023-04-18 21:51:35 -07:00 |
|
Tri Dao
|
605655bc66
|
[Gen] Fix FT kernel when using CG
|
2023-04-14 16:50:01 -07:00 |
|
Tri Dao
|
393882bc08
|
[LayerNorm] Implement LN with parallel residual, support dim 8k
|
2023-03-31 14:23:45 -07:00 |
|
Tri Dao
|
993d12448e
|
Implement GPT-NeoX
|
2023-03-29 01:21:25 -07:00 |
|
Tri Dao
|
f5d0fbd468
|
[FT] Fix FT's single query attention for bf16 hdim128 rotary
|
2023-03-28 21:27:00 -07:00 |
|
Tri Dao
|
4d87e4d875
|
Implement GPT-J
|
2023-03-22 16:16:58 -07:00 |
|
Tri Dao
|
78b7a1dc18
|
[OPT] Load fp16 weights on CPU before moving to GPU
|
2023-01-22 17:01:32 -08:00 |
|
Tri Dao
|
f68d41ec77
|
[Gen] Add OPT to generation test
|
2023-01-17 19:59:06 -08:00 |
|
Tri Dao
|
88173a1aaf
|
[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP
|
2023-01-17 18:12:27 -08:00 |
|
Tri Dao
|
780e8eeabb
|
[ViT] Support timm checkpoint, add tests
|
2023-01-16 01:20:34 -08:00 |
|
Tri Dao
|
ff34123bd4
|
Reorder LN in Block, support OPT
|
2023-01-15 22:14:31 -08:00 |
|
Tri Dao
|
f1e01c27ba
|
[Gen] Pass qkv_stride to ft_attention kernel for batched generation
|
2023-01-15 15:20:01 -08:00 |
|
Tri Dao
|
7c2191542a
|
[Gen] Make generation work with Tensor Parallel
|
2023-01-15 11:34:27 -08:00 |
|
Tri Dao
|
b48599002a
|
[Gen] Add timing option
|
2023-01-07 19:05:09 -08:00 |
|
Tri Dao
|
0938298e4c
|
[Gen] Adjust shape of kv_cache when using FT
|
2023-01-07 17:27:54 -08:00 |
|
Tri Dao
|
e02fd588aa
|
[Gen] Implement top-k and top-p sampling
|
2023-01-07 17:00:02 -08:00 |
|