[Core] Sharded State Loader download from HF (#4889)

This commit is contained in:
Aurick Qiao 2024-05-20 14:46:12 -04:00 committed by GitHub
parent f0eecee610
commit 1937e29848
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -423,6 +423,16 @@ class ShardedStateLoader(BaseModelLoader):
result[k] = t result[k] = t
return result return result
def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]):
if os.path.isdir(model_name_or_path):
return model_name_or_path
else:
allow_patterns = ["*.safetensors"]
return download_weights_from_hf(model_name_or_path,
self.load_config.download_dir,
allow_patterns, revision)
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
@ -433,6 +443,10 @@ class ShardedStateLoader(BaseModelLoader):
from safetensors.torch import safe_open from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
local_model_path = self._prepare_weights(model_config.model,
model_config.revision)
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
@ -440,7 +454,7 @@ class ShardedStateLoader(BaseModelLoader):
cache_config) cache_config)
rank = get_tensor_model_parallel_rank() rank = get_tensor_model_parallel_rank()
pattern = os.path.join( pattern = os.path.join(
model_config.model, local_model_path,
self.pattern.format(rank=rank, part="*"), self.pattern.format(rank=rank, part="*"),
) )
filepaths = glob.glob(pattern) filepaths = glob.glob(pattern)