[GPT] Fix loading weights from HF hub

Tri Dao 2023-08-21 22:56:02 -07:00
parent a8c35b4f57
commit ef6d8c75d9
2 changed files with 9 additions and 2 deletions


@@ -44,6 +44,14 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         )
         is_sharded = True
         load_safe = True
+    else:  # Try loading from HF hub instead of from local files
+        resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+        if resolved_archive_file is None:
+            resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
+                                                _raise_exceptions_for_missing_entries=False)
+            if resolved_archive_file is not None:
+                is_sharded = True
 
     if resolved_archive_file is None:
         raise EnvironmentError(f"Model name {model_name} was not found.")
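The hunk above adds a fallback: when no checkpoint is found in a local directory, the loader now queries the HF hub, first for a single weight file (WEIGHTS_NAME) and then for a sharded-checkpoint index (WEIGHTS_INDEX_NAME), instead of failing immediately. Below is a minimal standalone sketch of that lookup order, assuming transformers.utils.cached_file and the standard weight-file constants; the helper name resolve_checkpoint is hypothetical and not part of this repo.

import os

from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, cached_file


def resolve_checkpoint(model_name):
    """Return (path, is_sharded): local files first, then the HF hub.

    Hypothetical helper sketching the lookup order used by
    state_dict_from_pretrained after this commit.
    """
    # A local checkpoint directory may hold a single weight file
    # (pytorch_model.bin) or a shard index (pytorch_model.bin.index.json).
    weights_path = os.path.join(model_name, WEIGHTS_NAME)
    weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
    if os.path.isfile(weights_path):
        return weights_path, False
    if os.path.isfile(weights_index_path):
        return weights_index_path, True
    # Hub fallback: cached_file downloads (or reuses the local cache) and,
    # with _raise_exceptions_for_missing_entries=False, returns None instead
    # of raising when the repo has no such file.
    resolved = cached_file(model_name, WEIGHTS_NAME,
                           _raise_exceptions_for_missing_entries=False)
    if resolved is not None:
        return resolved, False
    resolved = cached_file(model_name, WEIGHTS_INDEX_NAME,
                           _raise_exceptions_for_missing_entries=False)
    if resolved is not None:
        return resolved, True
    raise EnvironmentError(f"Model name {model_name} was not found.")

For instance, resolve_checkpoint("bert-base-uncased") should return a cached pytorch_model.bin path with is_sharded=False, even with no local copy of the model.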


@@ -43,8 +43,7 @@ def get_hf_models(model_name, config, dtype):
     return model_hf
 
 
-@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
-# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
+@pytest.mark.parametrize('model_name', ["bert-base-uncased"])
 def test_bert_non_optimized(model_name):
     """Check that our implementation of BERT (without any optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
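For reference, pytest.mark.parametrize collects one test item per listed value, and pytest rejects two parametrize decorators for the same argument name on one test, which is why alternative model lists in this file are swapped by commenting decorators in and out rather than stacking them. A minimal sketch of the mechanism (the test name and body are placeholders, not the repo's):

import pytest


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
def test_model_name(model_name):
    # Collected as two test items:
    #   test_model_name[bert-base-uncased]
    #   test_model_name[bert-large-uncased]
    assert model_name.startswith("bert")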