[Rotary] Fix tests when loading state dict with rotary inv_freqs
This commit is contained in:
parent
b252072409
commit
8e9820a55b
@ -20,10 +20,8 @@ def test_gptj_state_dict(model_name):
|
||||
pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
|
||||
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
|
||||
state_dict = model.state_dict()
|
||||
rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
|
||||
for l in range(config.n_layer)}
|
||||
assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
|
||||
for k in state_dict.keys() - rotary_inv_freq_keys:
|
||||
assert state_dict.keys() == pretrained_state_dict.keys()
|
||||
for k in state_dict.keys():
|
||||
assert state_dict[k].shape == pretrained_state_dict[k].shape
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
# To run the huggingface implementation, we first need to convert the weights:
|
||||
# https://github.com/huggingface/transformers/pull/21955
|
||||
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR$/llama/7B-hf
|
||||
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
|
||||
# and repeat for 13B, 30B, 65B
|
||||
|
||||
import os
|
||||
@ -32,10 +32,8 @@ def test_llama_state_dict(model_name):
|
||||
pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
|
||||
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
|
||||
state_dict = model.state_dict()
|
||||
rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
|
||||
for l in range(config.n_layer)}
|
||||
assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
|
||||
for k in state_dict.keys() - rotary_inv_freq_keys:
|
||||
assert state_dict.keys() == pretrained_state_dict.keys()
|
||||
for k in state_dict.keys():
|
||||
assert state_dict[k].shape == pretrained_state_dict[k].shape
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user