JDKWangGuan
0d810cfb73
Fix KeyError handling for non-existing key in state_dict.pop() ( #898 )
...
Updated the handling of state_dict.pop() for non-existing keys.
Changed state_dict.pop(f"h.{d}.attn.bias") to state_dict.pop(f"h.{d}.attn.bias", None) so that a missing key no longer raises KeyError.
The following code reproduces the issue:
```
from transformers import GPT2Config, GPT2Model
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel
# >>> transformers.__version__
# '4.38.2'
model_path = 'gpt2'
output_model_path = 'gpt2_model'
config = GPT2Config.from_pretrained(model_path, output_hidden_states=True)
model = GPT2Model.from_pretrained(model_path, from_tf=False, config=config)
# (model fine-tuning would happen here)
# dump the fine-tuned model
model.save_pretrained(output_model_path)
# load the fine-tuned model
config = GPT2Config.from_pretrained(output_model_path, output_hidden_states=True)
model = GPTModel.from_pretrained(output_model_path, config=config, strict=True)  # fails with KeyError: 'h.0.attn.bias'
model = GPTLMHeadModel.from_pretrained(output_model_path, config=config, strict=True)  # fails with KeyError: 'h.0.attn.bias'
```
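A minimal sketch of the changed behavior, assuming a remap step that strips the per-layer causal-mask buffer from a Hugging Face GPT-2 state_dict (drop_causal_mask_buffers is a hypothetical helper name, not the repository's function):
```
# Checkpoints re-saved by recent transformers versions no longer contain
# "h.{d}.attn.bias" (as the repro above shows), so pop() needs a default
# instead of raising KeyError.
def drop_causal_mask_buffers(state_dict, n_layer):  # hypothetical helper name
    for d in range(n_layer):
        # Before the fix: state_dict.pop(f"h.{d}.attn.bias") -> KeyError if the key is absent
        state_dict.pop(f"h.{d}.attn.bias", None)  # after the fix: a missing key is silently ignored
    return state_dict
```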
2024-06-30 22:40:03 -07:00
Tri Dao
ef0ed10622
Add window_size option to MHA and GPT
2024-01-31 02:42:23 -08:00
Tri Dao
abbc131173
[LayerNorm] Switch from CUDA to Triton implementation
2024-01-05 00:31:17 -08:00
Tri Dao
73df3be7d5
Add test for BTLM init
2023-12-25 15:16:27 -08:00
Tri Dao
2e29dacf0c
Implement muParam
2023-12-24 20:34:48 -08:00
Tri Dao
2c7d7b7396
Implement norm head for Baichuan2
2023-12-22 16:55:40 -08:00
Tri Dao
c3b2196652
Add Alibi to MHA, test with Baichuan-13B
2023-12-21 22:49:55 -08:00
Yuchao Dai
187c2a0635
Fix E1136 ( #563 )
2023-09-21 11:48:23 -07:00
Tri Dao
d0032700d1
Add tests for Pythia, GPT-JT, and RedPajama models
2023-09-13 01:10:39 -07:00
Kevin Hu
07005806ff
Add BigCode converters ( #532 )
2023-09-10 17:24:50 -07:00
Tri Dao
798858f9f1
Fix test_baichuan
2023-09-03 21:01:37 -07:00
Tri Dao
7b33743a72
[Gen] Add back num_last_tokens in gpt.py
2023-09-03 20:44:40 -07:00
dan_the_3rd
011ec323d6
Support MQA + MP for decoding ( #490 )
...
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:29:54 -07:00
Tri Dao
f8aea6ead0
[GPT] Generalize last_token_only arg to num_last_tokens
2023-08-26 20:47:53 -07:00
Aman Gupta Karmani
e0b09891c6
add llama support to GPTPreTrainedModel.from_pretrained ( #479 )
2023-08-24 16:31:16 -07:00
Xuechen Li
25d6b1dbcb
handle uneven heads across ranks when combining state_dicts; resolves #467 ( #468 )
...
* q
* add comment.
2023-08-20 14:57:34 -07:00
Xuechen Li
7fcd3e6a04
map custom model state_dict back to huggingface format ( #465 )
...
* fix name.
* set inv function.
* add map back function.
* handle gqa.
* add type annotation to avoid confusion.
* fix docstr.
* test inverse remap logic.
2023-08-18 20:51:39 -07:00
Tri Dao
f1a73d0740
Run isort and black on python files
2023-08-18 14:22:11 -07:00
Xuechen Li
bb4cded17b
support when num_heads is not divisible by world_size; resolves #459 ( #461 )
...
* uneql rank.
* trim.
* enable passing in number of heads for each rank.
* simplify.
* simplify.
* cleanup.
* fix col parallel.
* fix bug with row parallel.
* fit out proj.
* refac.
* fix sharding logic.
* refac sharding.
* refac.
* support multiple of.
* make fn reuseable.
* fix bug in dimensions.
* scaffold.
* test uneven heads.
* fix test by adding barrier.
* refac.
* reuse code.
* clean up.
2023-08-18 14:10:35 -07:00
Tri Dao
4b661a569d
[GPT] Run black on gpt.py
2023-08-16 23:47:50 -07:00
Tri Dao
184b992dcb
[GPT] Implement parallel LLaMa
2023-07-28 15:52:48 -10:00
Haodong Lyu
8ee62efca3
Implement ParallelGatedMlp ( #251 )
2023-07-26 12:14:15 -07:00
Tri Dao
d38357dd2f
[GPT] Implement Falcon
2023-07-23 10:32:29 -07:00
Tri Dao
425dbcb6c6
[MHA] Implement MQA/GQA
2023-07-23 00:06:58 -07:00
Tri Dao
ec9f74ab9a
[Rotary] Don't store inv_freq in state_dict
2023-07-22 23:52:42 -07:00
Tri Dao
75e334d407
[MLP] Add ParallelMLP
2023-07-22 23:45:51 -07:00
Tri Dao
48bc6eacd6
[Gen] Add rotary base as an argument to FT attention kernel
2023-05-30 13:38:34 -07:00
Federico Berto
69f5f7d0a2
[BugFix] cannot unpack non-iterable NoneType object
2023-05-07 03:07:44 +09:00
Tri Dao
a9a4b4e4f2
[LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm
2023-05-04 23:39:43 -07:00
Tri Dao
ba2fe7f378
[Gen] Move allocate_inference_cache to within the model
2023-04-20 18:15:12 -07:00
Tri Dao
3da42d24b1
[GPT] Add option to only return the logit for the last token
2023-04-20 17:21:08 -07:00
Tri Dao
96d10f6545
Implement LLaMa
2023-04-18 21:51:35 -07:00
Tri Dao
b630aef53f
Implement GatedMlp
2023-04-18 03:37:14 -07:00
Tri Dao
6f6e9a9aaf
[FusedDense] Enable sqrelu activation in FusedMLP
2023-04-13 15:29:32 -07:00
Tri Dao
393882bc08
[LayerNorm] Implement LN with parallel residual, support dim 8k
2023-03-31 14:23:45 -07:00
Tri Dao
993d12448e
Implement GPT-NeoX
2023-03-29 01:21:25 -07:00
Tri Dao
4d87e4d875
Implement GPT-J
2023-03-22 16:16:58 -07:00
Tri Dao
78b7a1dc18
[OPT] Load fp16 weights on CPU before moving to GPU
2023-01-22 17:01:32 -08:00
Tri Dao
eb33e587e9
[LayerNorm] Rename x1 -> residual
2023-01-19 13:07:27 -08:00
Tri Dao
88173a1aaf
[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP
2023-01-17 18:12:27 -08:00
Tri Dao
ff34123bd4
Reorder LN in Block, support OPT
2023-01-15 22:14:31 -08:00
Tri Dao
7c2191542a
[Gen] Make generation work with Tensor Parallel
2023-01-15 11:34:27 -08:00
Tri Dao
11be742aa3
[Gen] Test generation with rotary embedding
2023-01-07 14:37:54 -08:00
Tri Dao
93383bd55b
[TP] Implement TensorParallel without sequence parallel
2023-01-07 13:45:22 -08:00
Tri Dao
714c1b4f0f
[Bert] Fix embedding layer norm before embedding dropout
2023-01-01 10:38:05 -08:00
Tri Dao
ef1ba918c6
[GPT] Refactor function to shard state_dict for TensorParallel
2023-01-01 00:09:33 -08:00
Tri Dao
63670fd84a
Implement generation for GPT
2022-12-27 21:01:50 -08:00
Tri Dao
9d797d8848
Support loading GPT2 weights from Huggingface
2022-12-27 11:22:48 -08:00
Tri Dao
b4018a5028
Implement Tensor Parallel for GPT model
2022-12-26 16:22:43 -08:00
Tri Dao
e68ebbe89a
Simplify FusedDense
2022-12-22 21:25:31 -08:00