[Bugfix] Fix Mistral v0.3 Weight Loading (#5005)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
parent
6a50f4cafa
commit
919770957f
@ -8,6 +8,7 @@ from .utils import check_logprobs_close
|
||||
|
||||
MODELS = [
|
||||
"mistralai/Mistral-7B-Instruct-v0.1",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -23,7 +23,8 @@ from vllm.model_executor.model_loader.tensorizer import (
|
||||
from vllm.model_executor.model_loader.utils import (get_model_architecture,
|
||||
set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_weights_from_hf, filter_files_not_needed_for_inference,
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
||||
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
|
||||
@ -188,7 +189,19 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
use_safetensors = True
|
||||
break
|
||||
|
||||
if not use_safetensors:
|
||||
if use_safetensors:
|
||||
# For models like Mistral-7B-Instruct-v0.3
|
||||
# there are both sharded safetensors files and a consolidated
|
||||
# safetensors file. Using both breaks.
|
||||
# Here, we download the `model.safetensors.index.json` and filter
|
||||
# any files not found in the index.
|
||||
if not is_local:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path, self.load_config.download_dir,
|
||||
revision)
|
||||
hf_weights_files = filter_duplicate_safetensors_files(
|
||||
hf_weights_files, hf_folder)
|
||||
else:
|
||||
hf_weights_files = filter_files_not_needed_for_inference(
|
||||
hf_weights_files)
|
||||
|
||||
|
||||
@ -12,9 +12,10 @@ import filelock
|
||||
import huggingface_hub.constants
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfFileSystem, snapshot_download
|
||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load_file, safe_open, save_file
|
||||
from tqdm.auto import tqdm
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
@ -218,6 +219,67 @@ def download_weights_from_hf(
|
||||
return hf_folder
|
||||
|
||||
|
||||
def download_safetensors_index_file_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
revision: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Download hf safetensors index file from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
"""
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
try:
|
||||
# Download the safetensors index file.
|
||||
hf_hub_download(
|
||||
repo_id=model_name_or_path,
|
||||
filename=SAFE_WEIGHTS_INDEX_NAME,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
)
|
||||
# If file not found on remote or locally, we should not fail since
|
||||
# only some models will have SAFE_WEIGHTS_INDEX_NAME.
|
||||
except huggingface_hub.utils.EntryNotFoundError:
|
||||
logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
|
||||
except huggingface_hub.utils.LocalEntryNotFoundError:
|
||||
logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
|
||||
|
||||
|
||||
# For models like Mistral-7B-v0.3, there are both sharded
|
||||
# safetensors files and a consolidated safetensors file.
|
||||
# Passing both of these to the weight loader functionality breaks.
|
||||
# So, we use the SAFE_WEIGHTS_INDEX_NAME to
|
||||
# look up which safetensors files should be used.
|
||||
def filter_duplicate_safetensors_files(hf_weights_files: List[str],
|
||||
hf_folder: str) -> List[str]:
|
||||
# model.safetensors.index.json is a mapping from keys in the
|
||||
# torch state_dict to safetensors file holding that weight.
|
||||
index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
|
||||
if not os.path.isfile(index_file_name):
|
||||
return hf_weights_files
|
||||
|
||||
# Iterate through the weight_map (weight_name: safetensors files)
|
||||
# to identify weights that we should use.
|
||||
with open(index_file_name) as index_file:
|
||||
weight_map = json.load(index_file)["weight_map"]
|
||||
weight_files_in_index = set()
|
||||
for weight_name in weight_map:
|
||||
weight_files_in_index.add(
|
||||
os.path.join(hf_folder, weight_map[weight_name]))
|
||||
# Filter out any fields that are not found in the index file.
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files if f in weight_files_in_index
|
||||
]
|
||||
return hf_weights_files
|
||||
|
||||
|
||||
def filter_files_not_needed_for_inference(
|
||||
hf_weights_files: List[str]) -> List[str]:
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user