[Core][Model] Support loading weights by ID within models (#7931)

Peter Salas authored 2024-09-24 00:14:15 -07:00; committed by GitHub
parent b8747e8a7c
commit 3f06bae907
2 changed files with 73 additions and 17 deletions
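
In brief, as reflected in the diff below: DefaultModelLoader gains a nested Source dataclass describing a single weights checkpoint (a model ID or local path, an optional revision, a prefix to prepend to every tensor name, and a flag controlling whether *.pt files may be used). A new _get_all_weights streams the model's primary checkpoint first, then any extra checkpoints the model itself declares in an optional secondary_weights attribute. UltravoxModel uses this hook to load its Whisper-based audio tower and its wrapped language model from separate model IDs.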

vllm/model_executor/model_loader/loader.py

@@ -1,6 +1,7 @@
 # ruff: noqa: SIM117
 import collections
 import copy
+import dataclasses
 import fnmatch
 import glob
 import json
@@ -8,7 +9,8 @@ import math
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Optional, Tuple, Type
+from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
+                    Type, cast)

 import gguf
 import huggingface_hub
@@ -207,6 +209,22 @@ class BaseModelLoader(ABC):
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""

+    @dataclasses.dataclass
+    class Source:
+        """A source for weights."""
+
+        model_or_path: str
+        """The model ID or path."""
+
+        revision: Optional[str]
+        """The optional model revision."""
+
+        prefix: str = ""
+        """A prefix to prepend to all weights."""
+
+        fall_back_to_pt: bool = True
+        """Whether .pt weights can be used."""
+
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
         if load_config.model_loader_extra_config:
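
As a usage sketch (the model ID below is hypothetical, not taken from this commit), a weight source is just a small value object:

    from vllm.model_executor.model_loader.loader import DefaultModelLoader

    # Describe an extra checkpoint whose tensor names should be
    # re-namespaced under "audio_tower." when streamed into the model.
    source = DefaultModelLoader.Source(
        model_or_path="openai/whisper-small",  # hypothetical model ID
        revision=None,
        prefix="audio_tower.",
        fall_back_to_pt=True,
    )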
@@ -313,17 +331,16 @@ class DefaultModelLoader(BaseModelLoader):
         return hf_folder, hf_weights_files, use_safetensors

     def _get_weights_iterator(
-            self, model_name_or_path: str, revision: Optional[str],
-            fall_back_to_pt: bool
+            self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
-            model_name_or_path, revision, fall_back_to_pt)
+            source.model_or_path, source.revision, source.fall_back_to_pt)

         if self.load_config.load_format == LoadFormat.NPCACHE:
             # Currently np_cache only support *.bin checkpoints
             assert use_safetensors is False
             weights_iterator = np_cache_weights_iterator(
-                model_name_or_path, self.load_config.download_dir, hf_folder,
+                source.model_or_path, self.load_config.download_dir, hf_folder,
                 hf_weights_files)
         elif use_safetensors:
             weights_iterator = safetensors_weights_iterator(hf_weights_files)
@@ -341,7 +358,29 @@
                 xm.mark_step()

             weights_iterator = _xla_weights_iterator(weights_iterator)
-        return weights_iterator
+
+        # Apply the prefix.
+        return ((source.prefix + name, tensor)
+                for (name, tensor) in weights_iterator)
+
+    def _get_all_weights(
+        self,
+        model_config: ModelConfig,
+        model: nn.Module,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+
+        primary_weights = DefaultModelLoader.Source(
+            model_config.model,
+            model_config.revision,
+            prefix="",
+            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
+                                    True))
+        yield from self._get_weights_iterator(primary_weights)
+
+        secondary_weights = cast(Iterable[DefaultModelLoader.Source],
+                                 getattr(model, "secondary_weights", ()))
+        for source in secondary_weights:
+            yield from self._get_weights_iterator(source)

     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model,
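
The net effect of _get_all_weights can be illustrated with a self-contained toy (tensor names here are invented): each source's names are re-namespaced by its prefix and the streams are concatenated.

    from typing import Iterable, Tuple

    import torch

    def with_prefix(prefix: str, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Mirrors the generator expression in _get_weights_iterator above.
        return ((prefix + name, tensor) for name, tensor in weights)

    primary = [("lm_head.weight", torch.zeros(2))]          # prefix ""
    secondary = [("encoder.conv1.weight", torch.zeros(2))]  # prefix "audio_tower."

    merged = list(with_prefix("", primary)) + list(
        with_prefix("audio_tower.", secondary))
    print([name for name, _ in merged])
    # ['lm_head.weight', 'audio_tower.encoder.conv1.weight']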
@@ -360,13 +399,8 @@
                 model = _initialize_model(model_config, self.load_config,
                                           lora_config, cache_config,
                                           scheduler_config)
-            model.load_weights(
-                self._get_weights_iterator(model_config.model,
-                                           model_config.revision,
-                                           fall_back_to_pt=getattr(
-                                               model,
-                                               "fall_back_to_pt_during_load",
-                                               True)), )
+
+            model.load_weights(self._get_all_weights(model_config, model))

         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
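
Putting the loader-side hooks together: a model participates purely through two optional attributes that the loader reads with getattr. A minimal sketch, with hypothetical names:

    import torch.nn as nn

    from vllm.model_executor.model_loader.loader import DefaultModelLoader

    class MyCompositeModel(nn.Module):
        """Hypothetical model illustrating the opt-in attributes."""

        # Read via getattr() when the loader builds the primary Source.
        fall_back_to_pt_during_load = True

        def __init__(self):
            super().__init__()
            # Read via getattr(); each Source is streamed after the
            # primary checkpoint by _get_all_weights.
            self.secondary_weights = [
                DefaultModelLoader.Source(
                    model_or_path="org/extra-weights",  # hypothetical ID
                    revision=None,
                    prefix="adapter."),
            ]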

vllm/model_executor/models/ultravox.py

@@ -25,6 +25,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.utils import (flatten_bn,
@@ -334,14 +335,23 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
         self.multi_modal_config = multimodal_config
         assert self.multi_modal_config

-        if config.audio_model_id is not None:
-            self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
-                config.audio_model_id)
-        else:
-            self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
+        self.secondary_weights = []
+        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
+        if config.audio_model_id is not None:
+            self.secondary_weights.append(
+                DefaultModelLoader.Source(
+                    model_or_path=config.audio_model_id,
+                    revision=None,
+                    prefix="audio_tower.",
+                ))
         self.multi_modal_projector = UltravoxProjector(config)
         self.language_model = init_vllm_registered_model(
             config.text_config, cache_config, quant_config)
+        if config.text_model_id is not None:
+            self.secondary_weights.append(
+                DefaultModelLoader.Source(model_or_path=config.text_model_id,
+                                          revision=None,
+                                          prefix="language_model."))

     def _audio_features_to_embeddings(
         self, input_features: torch.Tensor) -> torch.Tensor:
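
The prefixes are chosen to match UltravoxModel's attribute paths, so the re-namespaced tensors land on the right submodule. Illustratively (the checkpoint-side name is an assumption, not from this diff):

    # A tensor named "model.embed_tokens.weight" in the text checkpoint
    # (illustrative name) reaches load_weights under the wrapped path:
    name = "language_model." + "model.embed_tokens.weight"
    print(name)  # language_model.model.embed_tokens.weight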
@@ -466,6 +476,18 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
         # prepare weight iterators for components
         weights_group = group_weights_with_prefix(weights)

+        # load audio tower weights
+        audio_tower_weights = weights_group["audio_tower"]
+        audio_tower_params_dict = dict(
+            self.audio_tower.named_parameters(
+                prefix=self.audio_tower.base_model_prefix))
+        for name, loaded_weight in audio_tower_weights:
+            if name in audio_tower_params_dict:
+                param = audio_tower_params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
         # load projector weights
         projector_weights = weights_group["multi_modal_projector"]
         projector_params_dict = dict(
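
The named_parameters(prefix=...) pattern used above can be sketched in isolation; base_model_prefix = "model" mirrors the usual Hugging Face convention but is an assumption here:

    import torch.nn as nn

    class ToyTower(nn.Module):
        base_model_prefix = "model"  # assumed, as in HF PreTrainedModel

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

    tower = ToyTower()
    # Keys come out as "model.proj.weight" / "model.proj.bias", so they can
    # be matched directly against checkpoint names carrying the same prefix.
    params = dict(tower.named_parameters(prefix=tower.base_model_prefix))
    print(sorted(params))  # ['model.proj.bias', 'model.proj.weight']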