[Core][Model] Support loading weights by ID within models (#7931)
This commit is contained in:
parent
b8747e8a7c
commit
3f06bae907
@ -1,6 +1,7 @@
|
|||||||
# ruff: noqa: SIM117
|
# ruff: noqa: SIM117
|
||||||
import collections
|
import collections
|
||||||
import copy
|
import copy
|
||||||
|
import dataclasses
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
@ -8,7 +9,8 @@ import math
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
|
from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
|
||||||
|
Type, cast)
|
||||||
|
|
||||||
import gguf
|
import gguf
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
@ -207,6 +209,22 @@ class BaseModelLoader(ABC):
|
|||||||
class DefaultModelLoader(BaseModelLoader):
|
class DefaultModelLoader(BaseModelLoader):
|
||||||
"""Model loader that can load different file types from disk."""
|
"""Model loader that can load different file types from disk."""
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class Source:
    """A source for weights.

    Describes one location that weights are read from, together with the
    name prefix that the weights iterator prepends to every weight name
    coming from that location.
    """

    model_or_path: str
    """The model ID or path."""

    # Defaults to None so callers that do not pin a revision can omit it.
    # The field keeps its position, so positional
    # ``Source(model, revision, ...)`` calls keep working.
    revision: Optional[str] = None
    """The optional model revision."""

    prefix: str = ""
    """A prefix to prepend to all weights."""

    fall_back_to_pt: bool = True
    """Whether .pt weights can be used."""
|
||||||
|
|
||||||
def __init__(self, load_config: LoadConfig):
|
def __init__(self, load_config: LoadConfig):
|
||||||
super().__init__(load_config)
|
super().__init__(load_config)
|
||||||
if load_config.model_loader_extra_config:
|
if load_config.model_loader_extra_config:
|
||||||
@ -313,17 +331,16 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
return hf_folder, hf_weights_files, use_safetensors
|
return hf_folder, hf_weights_files, use_safetensors
|
||||||
|
|
||||||
def _get_weights_iterator(
|
def _get_weights_iterator(
|
||||||
self, model_name_or_path: str, revision: Optional[str],
|
self, source: "Source"
|
||||||
fall_back_to_pt: bool
|
|
||||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
"""Get an iterator for the model weights based on the load format."""
|
"""Get an iterator for the model weights based on the load format."""
|
||||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||||
model_name_or_path, revision, fall_back_to_pt)
|
source.model_or_path, source.revision, source.fall_back_to_pt)
|
||||||
if self.load_config.load_format == LoadFormat.NPCACHE:
|
if self.load_config.load_format == LoadFormat.NPCACHE:
|
||||||
# Currently np_cache only supports *.bin checkpoints
|
# Currently np_cache only supports *.bin checkpoints
|
||||||
assert use_safetensors is False
|
assert use_safetensors is False
|
||||||
weights_iterator = np_cache_weights_iterator(
|
weights_iterator = np_cache_weights_iterator(
|
||||||
model_name_or_path, self.load_config.download_dir, hf_folder,
|
source.model_or_path, self.load_config.download_dir, hf_folder,
|
||||||
hf_weights_files)
|
hf_weights_files)
|
||||||
elif use_safetensors:
|
elif use_safetensors:
|
||||||
weights_iterator = safetensors_weights_iterator(hf_weights_files)
|
weights_iterator = safetensors_weights_iterator(hf_weights_files)
|
||||||
@ -341,7 +358,29 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
|
|
||||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
weights_iterator = _xla_weights_iterator(weights_iterator)
|
||||||
return weights_iterator
|
|
||||||
|
# Apply the prefix.
|
||||||
|
return ((source.prefix + name, tensor)
|
||||||
|
for (name, tensor) in weights_iterator)
|
||||||
|
|
||||||
|
def _get_all_weights(
    self,
    model_config: ModelConfig,
    model: nn.Module,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield (name, tensor) pairs from every weight source of the model.

    The primary source is built from ``model_config.model`` and
    ``model_config.revision`` with an empty prefix. After it is exhausted,
    each entry of the model's optional ``secondary_weights`` attribute is
    streamed in order through ``_get_weights_iterator``.
    """
    checkpoint = model_config.model
    checkpoint_revision = model_config.revision
    # Models may opt out of .pt fallback via this attribute; default is on.
    allow_pt = getattr(model, "fall_back_to_pt_during_load", True)

    main_source = DefaultModelLoader.Source(
        checkpoint,
        checkpoint_revision,
        prefix="",
        fall_back_to_pt=allow_pt,
    )
    yield from self._get_weights_iterator(main_source)

    # Looked up only after the primary weights have been fully consumed;
    # models that declare no secondary sources contribute nothing here.
    extra_sources = cast(
        Iterable[DefaultModelLoader.Source],
        getattr(model, "secondary_weights", ()),
    )
    for extra_source in extra_sources:
        yield from self._get_weights_iterator(extra_source)
|
||||||
|
|
||||||
def download_model(self, model_config: ModelConfig) -> None:
|
def download_model(self, model_config: ModelConfig) -> None:
|
||||||
self._prepare_weights(model_config.model,
|
self._prepare_weights(model_config.model,
|
||||||
@ -360,13 +399,8 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
model = _initialize_model(model_config, self.load_config,
|
model = _initialize_model(model_config, self.load_config,
|
||||||
lora_config, cache_config,
|
lora_config, cache_config,
|
||||||
scheduler_config)
|
scheduler_config)
|
||||||
model.load_weights(
|
|
||||||
self._get_weights_iterator(model_config.model,
|
model.load_weights(self._get_all_weights(model_config, model))
|
||||||
model_config.revision,
|
|
||||||
fall_back_to_pt=getattr(
|
|
||||||
model,
|
|
||||||
"fall_back_to_pt_during_load",
|
|
||||||
True)), )
|
|
||||||
|
|
||||||
for _, module in model.named_modules():
|
for _, module in model.named_modules():
|
||||||
quant_method = getattr(module, "quant_method", None)
|
quant_method = getattr(module, "quant_method", None)
|
||||||
|
|||||||
@ -25,6 +25,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
|||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
|
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||||
from vllm.model_executor.models.utils import (flatten_bn,
|
from vllm.model_executor.models.utils import (flatten_bn,
|
||||||
@ -334,14 +335,23 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
|
|||||||
self.multi_modal_config = multimodal_config
|
self.multi_modal_config = multimodal_config
|
||||||
assert self.multi_modal_config
|
assert self.multi_modal_config
|
||||||
|
|
||||||
|
self.secondary_weights = []
|
||||||
|
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
|
||||||
if config.audio_model_id is not None:
|
if config.audio_model_id is not None:
|
||||||
self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
|
self.secondary_weights.append(
|
||||||
config.audio_model_id)
|
DefaultModelLoader.Source(
|
||||||
else:
|
model_or_path=config.audio_model_id,
|
||||||
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
|
revision=None,
|
||||||
|
prefix="audio_tower.",
|
||||||
|
))
|
||||||
self.multi_modal_projector = UltravoxProjector(config)
|
self.multi_modal_projector = UltravoxProjector(config)
|
||||||
self.language_model = init_vllm_registered_model(
|
self.language_model = init_vllm_registered_model(
|
||||||
config.text_config, cache_config, quant_config)
|
config.text_config, cache_config, quant_config)
|
||||||
|
if config.text_model_id is not None:
|
||||||
|
self.secondary_weights.append(
|
||||||
|
DefaultModelLoader.Source(model_or_path=config.text_model_id,
|
||||||
|
revision=None,
|
||||||
|
prefix="language_model."))
|
||||||
|
|
||||||
def _audio_features_to_embeddings(
|
def _audio_features_to_embeddings(
|
||||||
self, input_features: torch.Tensor) -> torch.Tensor:
|
self, input_features: torch.Tensor) -> torch.Tensor:
|
||||||
@ -466,6 +476,18 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
|
|||||||
# prepare weight iterators for components
|
# prepare weight iterators for components
|
||||||
weights_group = group_weights_with_prefix(weights)
|
weights_group = group_weights_with_prefix(weights)
|
||||||
|
|
||||||
|
# load audio tower weights
|
||||||
|
audio_tower_weights = weights_group["audio_tower"]
|
||||||
|
audio_tower_params_dict = dict(
|
||||||
|
self.audio_tower.named_parameters(
|
||||||
|
prefix=self.audio_tower.base_model_prefix))
|
||||||
|
for name, loaded_weight in audio_tower_weights:
|
||||||
|
if name in audio_tower_params_dict:
|
||||||
|
param = audio_tower_params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
# load projector weights
|
# load projector weights
|
||||||
projector_weights = weights_group["multi_modal_projector"]
|
projector_weights = weights_group["multi_modal_projector"]
|
||||||
projector_params_dict = dict(
|
projector_params_dict = dict(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user