[Core] Sharded State Loader download from HF (#4889)
This commit is contained in:
parent
f0eecee610
commit
1937e29848
```diff
@@ -423,6 +423,16 @@ class ShardedStateLoader(BaseModelLoader):
                 result[k] = t
         return result
 
+    def _prepare_weights(self, model_name_or_path: str,
+                         revision: Optional[str]):
+        if os.path.isdir(model_name_or_path):
+            return model_name_or_path
+        else:
+            allow_patterns = ["*.safetensors"]
+            return download_weights_from_hf(model_name_or_path,
+                                            self.load_config.download_dir,
+                                            allow_patterns, revision)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
```
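The new `_prepare_weights` helper uses a local directory as-is and otherwise treats the argument as a Hugging Face repo id, downloading only the `*.safetensors` shards. A minimal standalone sketch of the same resolve-or-download logic, using `huggingface_hub.snapshot_download` directly rather than vLLM's `download_weights_from_hf` wrapper (the function name and defaults below are illustrative, not part of this diff):

```python
import os
from typing import Optional

from huggingface_hub import snapshot_download


def prepare_weights(model_name_or_path: str,
                    revision: Optional[str] = None,
                    download_dir: Optional[str] = None) -> str:
    """Return a local directory holding the model's *.safetensors files."""
    if os.path.isdir(model_name_or_path):
        # Already a local checkout; use it as-is.
        return model_name_or_path
    # Treat the argument as a Hugging Face repo id and fetch only the
    # safetensors shards; snapshot_download returns the local snapshot path.
    return snapshot_download(model_name_or_path,
                             revision=revision,
                             cache_dir=download_dir,
                             allow_patterns=["*.safetensors"])
```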
```diff
@@ -433,6 +443,10 @@ class ShardedStateLoader(BaseModelLoader):
         from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
 
+        local_model_path = self._prepare_weights(model_config.model,
+                                                 model_config.revision)
+
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config,
```
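`load_model` now resolves the checkpoint location up front, before any shard files are opened. Downstream, the loader reads each shard with `safetensors` (the `safe_open` import is visible in the context lines above); a brief sketch of that read path, with a made-up shard filename:

```python
from safetensors.torch import safe_open

# Read every tensor out of one per-rank shard file. The real loader also
# matches tensor names against the model's state dict; this only shows
# the safetensors access pattern. The filename is hypothetical.
with safe_open("model-rank-0-part-0.safetensors", framework="pt") as f:
    for name in f.keys():
        tensor = f.get_tensor(name)
        print(name, tuple(tensor.shape))
```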
```diff
@@ -440,7 +454,7 @@ class ShardedStateLoader(BaseModelLoader):
                                           cache_config)
                 rank = get_tensor_model_parallel_rank()
                 pattern = os.path.join(
-                    model_config.model,
+                    local_model_path,
                     self.pattern.format(rank=rank, part="*"),
                 )
                 filepaths = glob.glob(pattern)
```
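The last hunk is the actual fix: the glob for per-rank shard files now runs against the resolved `local_model_path` instead of the raw `model_config.model` string, so passing a Hub repo id no longer yields an empty file list. With that in place, a sharded checkpoint can be pulled straight from the Hub; a hedged usage sketch (the repo id is made up, and the repo must contain shards matching the loader's `{rank}`/`{part}` filename pattern):

```python
from vllm import LLM

# load_format="sharded_state" selects the ShardedStateLoader; the model id
# below is illustrative and must point at a repo containing per-rank
# safetensors shards (e.g. "model-rank-{rank}-part-{part}.safetensors").
llm = LLM(model="someorg/llama-7b-sharded-tp1",
          load_format="sharded_state",
          tensor_parallel_size=1)
```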