diff --git a/tests/models/test_gptj.py b/tests/models/test_gptj.py index 078fcab..b4fdd53 100644 --- a/tests/models/test_gptj.py +++ b/tests/models/test_gptj.py @@ -20,10 +20,8 @@ def test_gptj_state_dict(model_name): pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config) model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow state_dict = model.state_dict() - rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq' - for l in range(config.n_layer)} - assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys - for k in state_dict.keys() - rotary_inv_freq_keys: + assert state_dict.keys() == pretrained_state_dict.keys() + for k in state_dict.keys(): assert state_dict[k].shape == pretrained_state_dict[k].shape diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py index 7f99ef1..a6b9c6c 100644 --- a/tests/models/test_llama.py +++ b/tests/models/test_llama.py @@ -2,7 +2,7 @@ # To run the huggingface implementation, we first need to convert the weights: # https://github.com/huggingface/transformers/pull/21955 -# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR$/llama/7B-hf +# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf # and repeat for 13B, 30B, 65B import os @@ -32,10 +32,8 @@ def test_llama_state_dict(model_name): pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config) model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow state_dict = model.state_dict() - rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq' - for l in range(config.n_layer)} - assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys - for k in state_dict.keys() - rotary_inv_freq_keys: + assert state_dict.keys() == pretrained_state_dict.keys() + for k in state_dict.keys(): assert state_dict[k].shape == pretrained_state_dict[k].shape