[Model][VLM] Decouple weight loading logic for Paligemma (#8269)
This commit is contained in:
parent e807125936
commit 36bf8150cc
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -1,3 +1,4 @@
+import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -13,7 +14,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.gemma import GemmaModel
+from vllm.model_executor.models.gemma import GemmaForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.utils import cached_get_tokenizer
@@ -22,14 +23,10 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import merge_multimodal_embeddings
+from .utils import filter_weights, merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
-_KEYS_TO_MODIFY_MAPPING = {
-    "language_model.model": "language_model",
-}
-
 
 class PaliGemmaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -151,8 +148,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
             projection_dim=config.vision_config.projection_dim)
 
         self.quant_config = quant_config
-        self.language_model = GemmaModel(config.text_config, cache_config,
-                                         quant_config)
+        self.language_model = GemmaForCausalLM(config.text_config,
+                                               cache_config, quant_config)
         self.unpadded_vocab_size = config.text_config.vocab_size
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
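The constructor change above is what enables the rest of this diff: the wrapper now owns a full GemmaForCausalLM instead of the bare GemmaModel, so logits computation, sampling, and weight loading can all be forwarded to it (see the hunks below) instead of being re-implemented with copied code. A toy sketch of that delegation pattern, with hypothetical names:

import torch
import torch.nn as nn


class ToyCausalLM(nn.Module):
    """Stands in for GemmaForCausalLM: a transformer plus LM-head logic."""

    def __init__(self, hidden: int, vocab: int) -> None:
        super().__init__()
        self.model = nn.Linear(hidden, hidden)  # stands in for GemmaModel
        self.embed_tokens = nn.Parameter(torch.randn(vocab, hidden))

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Tied embeddings: project hidden states onto the embedding matrix.
        return hidden_states @ self.embed_tokens.T


class ToyVLM(nn.Module):
    """Stands in for PaliGemmaForConditionalGeneration."""

    def __init__(self, hidden: int = 8, vocab: int = 16) -> None:
        super().__init__()
        self.language_model = ToyCausalLM(hidden, vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The forward pass still drives the inner transformer directly...
        return self.language_model.model(x)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # ...while logits (and, in the real model, sample() and
        # load_weights()) are delegated to the wrapped causal LM.
        return self.language_model.compute_logits(hidden_states)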
@@ -252,7 +249,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
             vision_embeddings = vision_embeddings * (self.config.hidden_size**
                                                      -0.5)
 
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids)
 
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
@@ -262,87 +260,47 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
         else:
             inputs_embeds = None
 
-        hidden_states = self.language_model(input_ids,
-                                            positions,
-                                            kv_caches,
-                                            attn_metadata,
-                                            None,
-                                            inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  None,
+                                                  inputs_embeds=inputs_embeds)
 
         return hidden_states
 
-    # Copied from vllm/model_executor/models/gemma.py
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.language_model.embed_tokens,
-                                       hidden_states, sampling_metadata)
-        return logits
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)
 
-    # Copied from vllm/model_executor/models/gemma.py
     def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        return self.language_model.sample(logits, sampling_metadata)
 
-    # Adapted from vllm/model_executor/models/gemma.py
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params = set()
-        for name, loaded_weight in weights:
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-            use_default_weight_loading = False
-            if "vision" not in name or self.vision_tower.shard_weight:
-                for (param_name, shard_name,
-                     shard_id) in stacked_params_mapping:
-                    if shard_name not in name:
-                        continue
-                    name = name.replace(shard_name, param_name)
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    # lm_head is not used in vllm as it is tied with
-                    # embed_token. To prevent errors, skip loading
-                    # lm_head.weight.
-                    if "lm_head.weight" in name:
-                        continue
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    use_default_weight_loading = True
-            else:
-                use_default_weight_loading = True
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
 
-            if use_default_weight_loading:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-
-                loaded_params.add(name)
-
-        unloaded_params = params_dict.keys() - loaded_params
-        if unloaded_params:
-            logger.warning(
-                "Some weights are not initialized from checkpoints: %s",
-                unloaded_params)
+        # load vision tower
+        vit_weights = filter_weights(vit_weights, "vision_tower")
+        self.vision_tower.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
+        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)
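This load_weights rewrite is the decoupling the commit title refers to: instead of one monolithic loop that renames keys via _KEYS_TO_MODIFY_MAPPING, the checkpoint iterator is split lazily with itertools.tee and each sub-module loads only the weights under its own prefix. A minimal runnable sketch of the pattern; the filter_weights below is a stand-in for the helper this diff imports from .utils and may differ from it in details:

import itertools
from typing import Iterable, Tuple

import torch


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
    # Yield only the (name, tensor) pairs under `prefix`, with the prefix
    # stripped, so each sub-module sees names relative to itself.
    for name, tensor in weights:
        head, _, tail = name.partition(".")
        if head == prefix:
            yield tail, tensor


# One pass over the checkpoint, teed into a lazy copy per component.
checkpoint = iter([
    ("vision_tower.embeddings.weight", torch.zeros(2)),
    ("multi_modal_projector.linear.weight", torch.zeros(2)),
    ("language_model.model.embed_tokens.weight", torch.zeros(2)),
])
vit_w, mlp_w, llm_w = itertools.tee(checkpoint, 3)

print(list(filter_weights(vit_w, "vision_tower")))
# -> [('embeddings.weight', tensor([0., 0.]))]
print(next(filter_weights(llm_w, "language_model"))[0])
# -> 'model.embed_tokens.weight'

Note that because the three teed copies are consumed one after another, itertools.tee buffers the stream internally; the win is organizational (each sub-module owns its own loading rules), not reduced memory.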
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -529,6 +529,12 @@ class SiglipVisionModel(nn.Module):
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ] if self.shard_weight else []
         params_dict = dict(self.named_parameters())
         layer_count = len(self.vision_model.encoder.layers)
 
@@ -544,7 +550,16 @@ class SiglipVisionModel(nn.Module):
             if layer_idx >= layer_count:
                 continue
 
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
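The mapping added above exists because vLLM fuses the attention projections into a single qkv_proj parameter while checkpoints store q_proj, k_proj and v_proj separately: each checkpoint tensor is routed to the fused parameter's weight_loader together with a shard id, and the `if self.shard_weight else []` guard leaves the mapping empty so the for/else falls through to the plain loader when the vision tower is not sharded. A simplified sketch of what such a shard-aware loader does (single GPU, not vLLM's actual QKVParallelLinear):

import torch

hidden = 8
# Fused parameter: q, k and v stacked along the output dimension.
qkv_proj = torch.empty(3 * hidden, hidden)


def qkv_weight_loader(param: torch.Tensor, loaded: torch.Tensor,
                      shard_id: str) -> None:
    # Copy one separately-stored q/k/v checkpoint tensor into its slice
    # of the fused parameter.
    offset = {"q": 0, "k": 1, "v": 2}[shard_id] * hidden
    param[offset:offset + hidden].copy_(loaded)


for shard_id in ("q", "k", "v"):
    qkv_weight_loader(qkv_proj, torch.randn(hidden, hidden), shard_id)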