From ef6d8c75d92f38ea9da4e8dc198d8c476a697902 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 21 Aug 2023 22:56:02 -0700
Subject: [PATCH] [GPT] Fix loading weights from HF hub

---
 flash_attn/utils/pretrained.py | 8 ++++++++
 tests/models/test_bert.py      | 3 +--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/flash_attn/utils/pretrained.py b/flash_attn/utils/pretrained.py
index 01c5b28..40e76bd 100644
--- a/flash_attn/utils/pretrained.py
+++ b/flash_attn/utils/pretrained.py
@@ -44,6 +44,14 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         )
         is_sharded = True
         load_safe = True
+    else:  # Try loading from HF hub instead of from local files
+        resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+        if resolved_archive_file is None:
+            resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
+                                                _raise_exceptions_for_missing_entries=False)
+            if resolved_archive_file is not None:
+                is_sharded = True
 
     if resolved_archive_file is None:
         raise EnvironmentError(f"Model name {model_name} was not found.")
diff --git a/tests/models/test_bert.py b/tests/models/test_bert.py
index 0bb26e1..f53cfb9 100644
--- a/tests/models/test_bert.py
+++ b/tests/models/test_bert.py
@@ -43,8 +43,7 @@ def get_hf_models(model_name, config, dtype):
     return model_hf
 
 
-@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
-# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
+@pytest.mark.parametrize('model_name', ["bert-base-uncased"])
 def test_bert_non_optimized(model_name):
     """Check that our implementation of BERT (without any optimizations enabled)
     matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF