From 1937e29848c8de8634c5421612d57863aa0e2a51 Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Mon, 20 May 2024 14:46:12 -0400
Subject: [PATCH] [Core] Sharded State Loader download from HF (#4889)

---
 vllm/model_executor/model_loader/loader.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index d1ab2075..45ea8160 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -423,6 +423,16 @@ class ShardedStateLoader(BaseModelLoader):
                 result[k] = t
         return result
 
+    def _prepare_weights(self, model_name_or_path: str,
+                         revision: Optional[str]):
+        if os.path.isdir(model_name_or_path):
+            return model_name_or_path
+        else:
+            allow_patterns = ["*.safetensors"]
+            return download_weights_from_hf(model_name_or_path,
+                                            self.load_config.download_dir,
+                                            allow_patterns, revision)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -433,6 +443,10 @@ class ShardedStateLoader(BaseModelLoader):
         from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
+
+        local_model_path = self._prepare_weights(model_config.model,
+                                                 model_config.revision)
+
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config,
@@ -440,7 +454,7 @@ class ShardedStateLoader(BaseModelLoader):
                                           cache_config)
                 rank = get_tensor_model_parallel_rank()
                 pattern = os.path.join(
-                    model_config.model,
+                    local_model_path,
                     self.pattern.format(rank=rank, part="*"),
                 )
                 filepaths = glob.glob(pattern)
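
Usage sketch: with _prepare_weights in place, load_format="sharded_state"
should accept a Hugging Face repo id as well as a local directory, since the
*.safetensors shards are downloaded into load_config.download_dir before the
per-rank glob runs. The snippet below is a minimal sketch under the assumption
that the repo already contains shards previously written by
ShardedStateLoader.save_model (named per self.pattern, e.g.
model-rank-{rank}-part-{part}.safetensors); a plain HF checkpoint would not
match that pattern. The repo id and download directory are placeholders, not
values from the patch.

    # Hypothetical end-to-end sketch, not part of the patch.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="my-org/llama-sharded-tp2",  # assumed HF repo id (placeholder)
        load_format="sharded_state",       # routes through ShardedStateLoader
        download_dir="/tmp/vllm-weights",  # becomes load_config.download_dir
        tensor_parallel_size=2,            # each rank globs only its own shards
    )
    print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))

Note that allow_patterns is limited to "*.safetensors", so only the weight
shards are fetched by this helper; tokenizer and config files are resolved
through vLLM's usual HF loading paths.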