[Core] Sharded State Loader download from HF (#4889)
parent f0eecee610
commit 1937e29848
@@ -423,6 +423,16 @@ class ShardedStateLoader(BaseModelLoader):
                     result[k] = t
         return result
 
+    def _prepare_weights(self, model_name_or_path: str,
+                         revision: Optional[str]):
+        if os.path.isdir(model_name_or_path):
+            return model_name_or_path
+        else:
+            allow_patterns = ["*.safetensors"]
+            return download_weights_from_hf(model_name_or_path,
+                                            self.load_config.download_dir,
+                                            allow_patterns, revision)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
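The new _prepare_weights() lets the sharded-state loader accept a Hugging Face model id as well as a local directory: local paths are returned unchanged, otherwise only the *.safetensors shards are downloaded via download_weights_from_hf(). Below is a minimal standalone sketch of that behavior; download_weights_from_hf is vLLM's own helper, and the snapshot_download call here is an assumption about roughly what it does, not the actual implementation.

import os
from typing import Optional

from huggingface_hub import snapshot_download


def prepare_weights(model_name_or_path: str,
                    download_dir: Optional[str] = None,
                    revision: Optional[str] = None) -> str:
    # Hypothetical standalone equivalent of ShardedStateLoader._prepare_weights.
    if os.path.isdir(model_name_or_path):
        # Already a local checkout; nothing to download.
        return model_name_or_path
    # Fetch only the safetensors shards from the Hugging Face Hub.
    return snapshot_download(model_name_or_path,
                             revision=revision,
                             cache_dir=download_dir,
                             allow_patterns=["*.safetensors"])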
@@ -433,6 +443,10 @@ class ShardedStateLoader(BaseModelLoader):
         from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
+
+        local_model_path = self._prepare_weights(model_config.model,
+                                                 model_config.revision)
+
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config,
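With load_model() now resolving the checkpoint through _prepare_weights(), a sharded-state checkpoint can be referenced by repo id instead of a pre-downloaded directory. A hedged usage sketch, assuming the "sharded_state" load format selects this loader and using a hypothetical repo id:

from vllm import LLM

llm = LLM(
    model="your-org/your-sharded-checkpoint",  # hypothetical HF repo id
    load_format="sharded_state",               # assumed to select ShardedStateLoader
    tensor_parallel_size=2,                    # must match how the shards were saved
)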
@@ -440,7 +454,7 @@ class ShardedStateLoader(BaseModelLoader):
                                           cache_config)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
-                model_config.model,
+                local_model_path,
                 self.pattern.format(rank=rank, part="*"),
             )
             filepaths = glob.glob(pattern)
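The only change in the final hunk is that the glob pattern is built from local_model_path (the directory returned by _prepare_weights) rather than model_config.model, so downloaded checkpoints resolve correctly. A small sketch of the per-rank file resolution, assuming a filename pattern of the form model-rank-{rank}-part-{part}.safetensors (the concrete pattern string is an assumption):

import glob
import os

local_model_path = "/path/to/downloaded/checkpoint"  # result of _prepare_weights
rank = 0                                             # tensor-parallel rank of this worker
pattern = os.path.join(
    local_model_path,
    "model-rank-{rank}-part-{part}.safetensors".format(rank=rank, part="*"),
)
filepaths = glob.glob(pattern)
# e.g. [".../model-rank-0-part-0.safetensors", ".../model-rank-0-part-1.safetensors"]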