[Rotary] Fix tests when loading state dict with rotary inv_freqs

Tri Dao 2023-07-26 07:16:10 -10:00
parent b252072409
commit 8e9820a55b
2 changed files with 5 additions and 9 deletions
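The diff below only touches the tests; the model-side change it implies is that rotary inv_freq buffers no longer appear in the model's state_dict(). A minimal sketch of that change, assuming the rotary module now registers inv_freq with persistent=False (the class below is an illustrative sketch, not the actual flash-attn source):

    import torch
    import torch.nn as nn

    class RotaryEmbedding(nn.Module):  # illustrative sketch, not the real flash-attn class
        def __init__(self, dim: int, base: float = 10000.0):
            super().__init__()
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
            # persistent=False keeps inv_freq out of state_dict(), so checkpoints
            # neither store it nor need to provide it on load; the buffer still
            # moves with .to() / .cuda() like any other buffer.
            self.register_buffer('inv_freq', inv_freq, persistent=False)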

@@ -20,10 +20,8 @@ def test_gptj_state_dict(model_name):
     pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
     state_dict = model.state_dict()
-    rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
-                            for l in range(config.n_layer)}
-    assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
-    for k in state_dict.keys() - rotary_inv_freq_keys:
+    assert state_dict.keys() == pretrained_state_dict.keys()
+    for k in state_dict.keys():
         assert state_dict[k].shape == pretrained_state_dict[k].shape

@@ -2,7 +2,7 @@
 # To run the huggingface implementation, we first need to convert the weights:
 # https://github.com/huggingface/transformers/pull/21955
-# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR$/llama/7B-hf
+# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
 # and repeat for 13B, 30B, 65B
 import os
@@ -32,10 +32,8 @@ def test_llama_state_dict(model_name):
     pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
     model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
     state_dict = model.state_dict()
-    rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
-                            for l in range(config.n_layer)}
-    assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
-    for k in state_dict.keys() - rotary_inv_freq_keys:
+    assert state_dict.keys() == pretrained_state_dict.keys()
+    for k in state_dict.keys():
         assert state_dict[k].shape == pretrained_state_dict[k].shape
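With inv_freq recomputed in __init__ instead of read from the checkpoint, the key sets match exactly, so both tests drop the rotary_inv_freq_keys special-casing. A hedged usage sketch of the effect (on a real device; the tests use device='meta' and only compare keys and shapes):

    model = GPTLMHeadModel(config)
    model.load_state_dict(pretrained_state_dict, strict=True)  # no missing/unexpected inv_freq keys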