[Rotary] Fix tests when loading state dict with rotary inv_freqs
This commit is contained in:
parent
b252072409
commit
8e9820a55b
@ -20,10 +20,8 @@ def test_gptj_state_dict(model_name):
|
|||||||
pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
|
pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
|
||||||
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
|
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
|
assert state_dict.keys() == pretrained_state_dict.keys()
|
||||||
for l in range(config.n_layer)}
|
for k in state_dict.keys():
|
||||||
assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
|
|
||||||
for k in state_dict.keys() - rotary_inv_freq_keys:
|
|
||||||
assert state_dict[k].shape == pretrained_state_dict[k].shape
|
assert state_dict[k].shape == pretrained_state_dict[k].shape
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
# To run the huggingface implementation, we first need to convert the weights:
|
# To run the huggingface implementation, we first need to convert the weights:
|
||||||
# https://github.com/huggingface/transformers/pull/21955
|
# https://github.com/huggingface/transformers/pull/21955
|
||||||
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR$/llama/7B-hf
|
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
|
||||||
# and repeat for 13B, 30B, 65B
|
# and repeat for 13B, 30B, 65B
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -32,10 +32,8 @@ def test_llama_state_dict(model_name):
|
|||||||
pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
|
pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
|
||||||
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
|
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
|
assert state_dict.keys() == pretrained_state_dict.keys()
|
||||||
for l in range(config.n_layer)}
|
for k in state_dict.keys():
|
||||||
assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
|
|
||||||
for k in state_dict.keys() - rotary_inv_freq_keys:
|
|
||||||
assert state_dict[k].shape == pretrained_state_dict[k].shape
|
assert state_dict[k].shape == pretrained_state_dict[k].shape
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user