[GPT] Fix loading weights from HF hub
This commit is contained in:
parent
a8c35b4f57
commit
ef6d8c75d9
@ -44,6 +44,14 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
|
||||
)
|
||||
is_sharded = True
|
||||
load_safe = True
|
||||
else: # Try loading from HF hub instead of from local files
|
||||
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
|
||||
_raise_exceptions_for_missing_entries=False)
|
||||
if resolved_archive_file is None:
|
||||
resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
|
||||
_raise_exceptions_for_missing_entries=False)
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
|
||||
if resolved_archive_file is None:
|
||||
raise EnvironmentError(f"Model name {model_name} was not found.")
|
||||
|
||||
@ -43,8 +43,7 @@ def get_hf_models(model_name, config, dtype):
|
||||
return model_hf
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
|
||||
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
|
||||
@pytest.mark.parametrize('model_name', ["bert-base-uncased"])
|
||||
def test_bert_non_optimized(model_name):
|
||||
"""Check that our implementation of BERT (without any optimizations enabled) matches the
|
||||
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
|
||||
|
||||
Loading…
Reference in New Issue
Block a user