[VLM] Enable overriding whether post layernorm is used in vision encoder + fix quant args (#9217)

Co-authored-by: Isotr0py <2037008807@qq.com>
Cyrus Leung 2024-10-23 19:27:37 +08:00 committed by GitHub
parent 3ff57ebfca
commit c18e1a3418
18 changed files with 551 additions and 253 deletions

View File

@@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional
 import torch
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -21,10 +22,12 @@ class AWQConfig(QuantizationConfig):
         weight_bits: int,
         group_size: int,
         zero_point: bool,
+        modules_to_not_convert: Optional[List[str]] = None,
     ) -> None:
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.zero_point = zero_point
+        self.modules_to_not_convert = modules_to_not_convert or []
         if self.weight_bits != 4:
             raise ValueError(
@@ -35,7 +38,8 @@ class AWQConfig(QuantizationConfig):
     def __repr__(self) -> str:
         return (f"AWQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
-                f"zero_point={self.zero_point})")
+                f"zero_point={self.zero_point}, "
+                f"modules_to_not_convert={self.modules_to_not_convert})")
     def get_name(self) -> str:
         return "awq"
@@ -61,11 +65,15 @@ class AWQConfig(QuantizationConfig):
         weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
         group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
         zero_point = cls.get_from_keys(config, ["zero_point"])
-        return cls(weight_bits, group_size, zero_point)
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None)
+        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["AWQLinearMethod"]:
+                         prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
             return AWQLinearMethod(self)
         return None
@@ -73,6 +81,10 @@ class AWQConfig(QuantizationConfig):
         return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
+def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
 class AWQLinearMethod(LinearMethodBase):
     """Linear method for AWQ.

View File

@@ -122,7 +122,7 @@ def input_processor_for_blip(
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164  # noqa
 class BlipVisionEmbeddings(nn.Module):
-    def __init__(self, config: BlipVisionConfig):
+    def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
         super().__init__()
         self.config = config
@@ -167,9 +167,10 @@ class BlipParallelAttention(nn.Module):
     def __init__(
         self,
-        config: BlipVisionConfig,
+        config: Union[BlipVisionConfig, Blip2VisionConfig],
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -189,11 +190,13 @@ class BlipParallelAttention(nn.Module):
             self.num_heads,
             bias=config.qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
         )
         self.projection = RowParallelLinear(
             self.embed_dim,
             self.embed_dim,
             quant_config=quant_config,
+            prefix=f"{prefix}.projection",
         )
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -235,9 +238,12 @@ class BlipParallelAttention(nn.Module):
 class BlipMLP(nn.Module):
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
@@ -246,11 +252,13 @@ class BlipMLP(nn.Module):
         self.fc1 = ColumnParallelLinear(config.hidden_size,
                                         config.intermediate_size,
                                         bias=True,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1")
         self.fc2 = RowParallelLinear(config.intermediate_size,
                                      config.hidden_size,
                                      bias=True,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2")
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -262,24 +270,32 @@ class BlipMLP(nn.Module):
 class BlipEncoderLayer(nn.Module):
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         # fallback to sdpa attention if tp unavailable
         num_heads = config.num_attention_heads
         tp_size = get_tensor_model_parallel_world_size()
         if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = BlipParallelAttention(config,
-                                                   quant_config=quant_config)
+            self.self_attn = BlipParallelAttention(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.self_attn",
+            )
         else:
             # Blip doesn't have SDPA attention implemented in transformers
             # use eager attention instead for cpu backend
             self.self_attn = BlipAttention(config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
-        self.mlp = BlipMLP(config, quant_config=quant_config)
+        self.mlp = BlipMLP(config,
+                           quant_config=quant_config,
+                           prefix=f"{prefix}.mlp")
         self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
@@ -307,10 +323,13 @@ class BlipEncoder(nn.Module):
         config: BlipConfig
     """
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
@@ -321,8 +340,10 @@ class BlipEncoder(nn.Module):
             num_hidden_layers = num_hidden_layers_override
         self.layers = nn.ModuleList([
-            BlipEncoderLayer(config=config, quant_config=quant_config)
-            for _ in range(num_hidden_layers)
+            BlipEncoderLayer(config=config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(num_hidden_layers)
         ])
     def forward(self, inputs_embeds: torch.Tensor):
@@ -337,10 +358,15 @@ class BlipVisionModel(nn.Module):
     config_class = BlipVisionConfig
     main_input_name = "pixel_values"
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         tp_size = get_tensor_model_parallel_world_size()
@@ -354,19 +380,24 @@ class BlipVisionModel(nn.Module):
             config=config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.encoder",
         )
+        num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
             raise ValueError(
-                f"The original encoder only has {config.num_hidden_layers} "
+                f"The original encoder only has {num_hidden_layers} "
                 f"layers, but you requested {len(self.encoder.layers)} layers."
             )
-        elif len(self.encoder.layers) == config.num_hidden_layers:
+        # If possible, skip post_layernorm to conserve memory
+        if require_post_norm is None:
+            require_post_norm = len(self.encoder.layers) == num_hidden_layers
+        if require_post_norm:
             self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                                eps=config.layer_norm_eps)
         else:
-            # post_layernorm is unused when we extract intermediate features
-            # In this case, we can skip it to conserve memory
            self.post_layernorm = None
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
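
Note: the CLIP vision tower below receives the same treatment. A minimal, self-contained sketch of the post-layernorm override (assumed stand-in names, not the actual BlipVisionModel):

    from typing import Optional
    import torch.nn as nn

    def build_post_layernorm(num_loaded_layers: int,
                             num_hidden_layers: int,
                             hidden_size: int,
                             require_post_norm: Optional[bool] = None):
        # Default: keep post_layernorm only when the full encoder is loaded;
        # callers can now force it on or off regardless of the layer count.
        if require_post_norm is None:
            require_post_norm = num_loaded_layers == num_hidden_layers
        return nn.LayerNorm(hidden_size) if require_post_norm else None

    # Truncated encoder used for feature extraction: norm skipped to save memory.
    assert build_post_layernorm(23, 24, 1024) is None
    # Explicit override keeps it even for a truncated encoder.
    assert build_post_layernorm(23, 24, 1024, require_post_norm=True) is not None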

View File

@@ -490,7 +490,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self.multimodal_config = multimodal_config
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_model = BlipVisionModel(config.vision_config)
+        self.vision_model = BlipVisionModel(config.vision_config, quant_config)
         self.query_tokens = nn.Parameter(
             torch.zeros(1, config.num_query_tokens,

View File

@@ -192,6 +192,7 @@ class CLIPParallelAttention(nn.Module):
         self,
         config: CLIPVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -211,12 +212,14 @@ class CLIPParallelAttention(nn.Module):
             head_size=self.head_dim,
             total_num_heads=self.num_heads,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.out_proj = RowParallelLinear(
             input_size=self.embed_dim,
             output_size=self.embed_dim,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -259,20 +262,25 @@ class CLIPParallelAttention(nn.Module):
 class CLIPMLP(nn.Module):
-    def __init__(self,
-                 config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
         self.fc1 = ColumnParallelLinear(config.hidden_size,
                                         config.intermediate_size,
                                         bias=True,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1")
         self.fc2 = RowParallelLinear(config.intermediate_size,
                                      config.hidden_size,
                                      bias=True,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2")
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -284,21 +292,29 @@ class CLIPMLP(nn.Module):
 class CLIPEncoderLayer(nn.Module):
-    def __init__(self,
-                 config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         num_heads = config.num_attention_heads
         tp_size = get_tensor_model_parallel_world_size()
         if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = CLIPParallelAttention(config,
-                                                   quant_config=quant_config)
+            self.self_attn = CLIPParallelAttention(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.self_attn",
+            )
         else:
             self.self_attn = CLIPSdpaAttention(config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
-        self.mlp = CLIPMLP(config, quant_config=quant_config)
+        self.mlp = CLIPMLP(config,
+                           quant_config=quant_config,
+                           prefix=f"{prefix}.mlp")
         self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
@@ -327,11 +343,15 @@ class CLIPEncoder(nn.Module):
         config: CLIPConfig
     """
-    def __init__(self,
-                 config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         if num_hidden_layers_override is None:
@@ -339,8 +359,10 @@ class CLIPEncoder(nn.Module):
         else:
             num_hidden_layers = num_hidden_layers_override
         self.layers = nn.ModuleList([
-            CLIPEncoderLayer(config=config, quant_config=quant_config)
-            for _ in range(num_hidden_layers)
+            CLIPEncoderLayer(config=config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(num_hidden_layers)
         ])
     def forward(self, inputs_embeds: torch.Tensor):
@@ -354,11 +376,17 @@ class CLIPEncoder(nn.Module):
 class CLIPVisionTransformer(nn.Module):
-    def __init__(self,
-                 config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         embed_dim = config.hidden_size
@@ -370,19 +398,25 @@ class CLIPVisionTransformer(nn.Module):
         self.encoder = CLIPEncoder(
             config=config,
             quant_config=quant_config,
-            num_hidden_layers_override=num_hidden_layers_override)
+            num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.encoder",
+        )
+        num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
             raise ValueError(
-                f"The original encoder only has {config.num_hidden_layers} "
+                f"The original encoder only has {num_hidden_layers} "
                 f"layers, but you requested {len(self.encoder.layers)} layers."
             )
-        elif len(self.encoder.layers) == config.num_hidden_layers:
+        # If possible, skip post_layernorm to conserve memory
+        if require_post_norm is None:
+            require_post_norm = len(self.encoder.layers) == num_hidden_layers
+        if require_post_norm:
             self.post_layernorm = nn.LayerNorm(embed_dim,
                                                eps=config.layer_norm_eps)
         else:
-            # post_layernorm is unused when we extract intermediate features
-            # In this case, we can skip it to conserve memory
            self.post_layernorm = None
     def forward(
@@ -405,10 +439,15 @@ class CLIPVisionModel(nn.Module):
     config_class = CLIPVisionConfig
     main_input_name = "pixel_values"
-    def __init__(self,
-                 config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         tp_size = get_tensor_model_parallel_world_size()
@@ -418,7 +457,10 @@ class CLIPVisionModel(nn.Module):
         self.vision_model = CLIPVisionTransformer(
             config=config,
             quant_config=quant_config,
-            num_hidden_layers_override=num_hidden_layers_override)
+            num_hidden_layers_override=num_hidden_layers_override,
+            require_post_norm=require_post_norm,
+            prefix=f"{prefix}.vision_model",
+        )
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return self.vision_model(pixel_values)
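
Note: the point of threading `prefix` through every submodule is that the quantization config sees each layer's fully-qualified name and can match it against `modules_to_not_convert`. A rough illustration of how the prefixes compose (assumed values; the actual root prefix depends on the parent model):

    # Hypothetical composition for the first CLIP encoder layer:
    root = "vision_model"                    # chosen by the parent model
    encoder = f"{root}.encoder"              # CLIPVisionTransformer
    layer0 = f"{encoder}.layers.0"           # CLIPEncoder
    qkv = f"{layer0}.self_attn.qkv_proj"     # CLIPParallelAttention
    print(qkv)  # vision_model.encoder.layers.0.self_attn.qkv_proj
    # AWQConfig.get_quant_method() receives such a string and returns
    # UnquantizedLinearMethod() when "vision_model" is listed in
    # modules_to_not_convert.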

View File

@@ -113,7 +113,8 @@ class Idefics2VisionAttention(nn.Module):
         self,
         config: Idefics2Config,
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -130,12 +131,14 @@ class Idefics2VisionAttention(nn.Module):
             self.head_dim,
             self.num_heads,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.out_proj = RowParallelLinear(
             self.embed_dim,
             self.embed_dim,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
@@ -178,7 +181,8 @@ class Idefics2VisionMLP(nn.Module):
         self,
         config: Idefics2Config,
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
@@ -187,12 +191,14 @@ class Idefics2VisionMLP(nn.Module):
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
         )
         self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
         )
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -204,13 +210,22 @@ class Idefics2VisionMLP(nn.Module):
 class Idefics2EncoderLayer(nn.Module):
-    def __init__(self, config: Idefics2Config):
+    def __init__(
+        self,
+        config: Idefics2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = Idefics2VisionAttention(config)
+        self.self_attn = Idefics2VisionAttention(config,
+                                                 quant_config=quant_config,
+                                                 prefix=f"{prefix}.self_attn")
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
-        self.mlp = Idefics2VisionMLP(config)
+        self.mlp = Idefics2VisionMLP(config,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.mlp")
         self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
@@ -245,12 +260,20 @@ class Idefics2Encoder(nn.Module):
         config: Idefics2Config
     """
-    def __init__(self, config: Idefics2Config):
+    def __init__(
+        self,
+        config: Idefics2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.layers = nn.ModuleList([
-            Idefics2EncoderLayer(config)
-            for _ in range(config.num_hidden_layers)
+            Idefics2EncoderLayer(config,
+                                 quant_config=quant_config,
+                                 prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(config.num_hidden_layers)
         ])
     def forward(
@@ -275,12 +298,20 @@ class Idefics2Encoder(nn.Module):
 class Idefics2VisionTransformer(nn.Module):
-    def __init__(self, config: Idefics2VisionConfig):
+    def __init__(
+        self,
+        config: Idefics2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         embed_dim = config.hidden_size
         self.config = config
         self.embeddings = Idefics2VisionEmbeddings(config)
-        self.encoder = Idefics2Encoder(config)
+        self.encoder = Idefics2Encoder(config,
+                                       quant_config=quant_config,
+                                       prefix=f"{prefix}.encoder")
         self.post_layernorm = nn.LayerNorm(embed_dim,
                                            eps=config.layer_norm_eps)

View File

@@ -137,6 +137,7 @@ class InternParallelAttention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         *,
         num_dummy_heads: int = 0,
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -165,6 +166,7 @@ class InternParallelAttention(nn.Module):
             num_dummy_heads + self.num_heads,
             bias=config.qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
         )
         self.qk_normalization = config.qk_normalization
@@ -181,6 +183,7 @@ class InternParallelAttention(nn.Module):
             self.dummy_dim,
             self.embed_dim,
             quant_config=quant_config,
+            prefix=f"{prefix}.proj",
         )
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
@@ -284,20 +287,26 @@ class InternSdpaAttention(nn.Module):
 class InternMLP(nn.Module):
-    def __init__(self,
-                 config: PretrainedConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
         self.fc1 = ColumnParallelLinear(config.hidden_size,
                                         config.intermediate_size,
                                         bias=True,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1")
         self.fc2 = RowParallelLinear(config.intermediate_size,
                                      config.hidden_size,
                                      bias=True,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2")
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -315,6 +324,7 @@ class InternVisionEncoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         *,
         num_dummy_heads: int = 0,
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -324,9 +334,12 @@ class InternVisionEncoderLayer(nn.Module):
         self.attn = self._init_attn(config,
                                     quant_config,
-                                    num_dummy_heads=num_dummy_heads)
+                                    num_dummy_heads=num_dummy_heads,
+                                    prefix=f"{prefix}.attn")
-        self.mlp = InternMLP(config, quant_config=quant_config)
+        self.mlp = InternMLP(config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.mlp")
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
                                              eps=config.layer_norm_eps)
         self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
@@ -343,6 +356,7 @@ class InternVisionEncoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig],
         *,
         num_dummy_heads: int,
+        prefix: str = "",
     ):
         # fallback to sdpa attention if tp unavailable
         tp_size = get_tensor_model_parallel_world_size()
@@ -351,7 +365,8 @@ class InternVisionEncoderLayer(nn.Module):
         if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
             return InternParallelAttention(config,
                                            quant_config=quant_config,
-                                           num_dummy_heads=num_dummy_heads)
+                                           num_dummy_heads=num_dummy_heads,
+                                           prefix=prefix)
         return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
@@ -377,6 +392,7 @@ class InternVisionEncoder(nn.Module):
         *,
         num_hidden_layers_override: Optional[int] = None,
         num_dummy_heads: int = 0,
+        prefix: str = "",
     ):
         super().__init__()
@@ -390,8 +406,9 @@ class InternVisionEncoder(nn.Module):
         self.layers = nn.ModuleList([
             InternVisionEncoderLayer(config,
                                      quant_config,
-                                     num_dummy_heads=num_dummy_heads)
-            for _ in range(num_hidden_layers)
+                                     num_dummy_heads=num_dummy_heads,
+                                     prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(num_hidden_layers)
         ])
     def forward(self, inputs_embeds: torch.Tensor):
@@ -412,7 +429,8 @@ class InternVisionModel(nn.Module):
         *,
         num_hidden_layers_override: Optional[int] = None,
         num_dummy_heads: int = 0,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
@@ -423,6 +441,7 @@ class InternVisionModel(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             num_dummy_heads=num_dummy_heads,
+            prefix=f"{prefix}.encoder",
         )
     def get_input_embeddings(self):

View File

@@ -19,7 +19,8 @@ from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
-from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization import (AWQConfig,
+                                                      QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                    InternVisionPatchModel)
@@ -418,11 +419,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.config = config
         self.multimodal_config = multimodal_config
+        self._patch_quant_config(config, quant_config)
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
-        self.select_layer = config.select_layer
         self.num_image_token = int(
             (image_size // patch_size)**2 * (config.downsample_ratio**2))
         self.downsample_ratio = config.downsample_ratio
@@ -430,7 +431,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.llm_arch_name = config.text_config.architectures[0]
         self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
-        self.vision_model = self._init_vision_model(config, self.is_mono)
+        self.vision_model = self._init_vision_model(
+            config,
+            quant_config=quant_config,
+            is_mono=self.is_mono,
+            prefix="vision_model",
+        )
         self.language_model = init_vllm_registered_model(
             config.text_config, cache_config, quant_config)
@@ -441,6 +447,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
+    def _patch_quant_config(self, config: PretrainedConfig,
+                            quant_config: QuantizationConfig):
+        # the awq models from OpenGVLab missing `modules_to_not_convert`
+        # patch the quant_config to add `modules_to_not_convert` back
+        if isinstance(quant_config, AWQConfig):
+            text_config = config.text_config
+            llm_quant_config = getattr(text_config, "quantization_config",
+                                       None)
+            if (not quant_config.modules_to_not_convert) and \
+                    (llm_quant_config is not None):
+                quant_config.modules_to_not_convert.append("vision_model")
     @cached_property
     def sampler(self):
         if hasattr(self.language_model, "sampler"):
@@ -448,17 +466,28 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         return Sampler()
-    def _init_vision_model(self, config: PretrainedConfig, is_mono: bool):
+    def _init_vision_model(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        *,
+        is_mono: bool,
+        prefix: str,
+    ):
         if not is_mono:
-            vision_feature_layer = self.select_layer
+            vision_feature_layer = config.select_layer
             if vision_feature_layer < 0:
                 num_hidden_layers = config.vision_config.num_hidden_layers \
                     + vision_feature_layer + 1
             else:
                 num_hidden_layers = vision_feature_layer + 1
             return InternVisionModel(
                 config.vision_config,
-                num_hidden_layers_override=num_hidden_layers)
+                quant_config=quant_config,
+                num_hidden_layers_override=num_hidden_layers,
+                prefix=prefix,
+            )
         else:
             return InternVisionPatchModel(config.vision_config)
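
Note: a sketch of what `_patch_quant_config` is for (simplified, assumed objects rather than the real config classes). The OpenGVLab AWQ checkpoints quantize only the language model but omit `modules_to_not_convert` from their config, so without the patch vLLM would try to apply AWQ to the unquantized vision tower:

    def patch_awq_for_internvl(quant_config, hf_config) -> None:
        # Only patch configs that do not already exclude any modules.
        llm_quant = getattr(hf_config.text_config, "quantization_config", None)
        if not quant_config.modules_to_not_convert and llm_quant is not None:
            # The ViT weights are stored unquantized; keep them that way.
            quant_config.modules_to_not_convert.append("vision_model")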

View File

@@ -1,12 +1,12 @@
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
-                    TypedDict, Union)
+from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
+                    Tuple, TypedDict, Union)
 import torch
 import torch.nn as nn
 from PIL import Image
 from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
-                          SiglipVisionConfig)
+                          PretrainedConfig, SiglipVisionConfig)
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -200,7 +200,17 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
     raise NotImplementedError(msg)
-def _init_vision_tower(hf_config: LlavaConfig):
+class LlavaLikeConfig(Protocol):
+    vision_config: PretrainedConfig
+    vision_feature_layer: int
+def init_vision_tower_for_llava(
+    hf_config: LlavaLikeConfig,
+    quant_config: Optional[QuantizationConfig],
+    *,
+    require_post_norm: Optional[bool] = None,
+):
     vision_config = hf_config.vision_config
     # Initialize the vision tower only up to the required feature layer
@@ -214,16 +224,24 @@ def _init_vision_tower(hf_config: LlavaConfig):
     if isinstance(vision_config, CLIPVisionConfig):
         return CLIPVisionModel(
             vision_config,
+            quant_config,
             num_hidden_layers_override=num_hidden_layers,
+            require_post_norm=require_post_norm,
         )
     elif isinstance(vision_config, SiglipVisionConfig):
         return SiglipVisionModel(
             vision_config,
+            quant_config,
             num_hidden_layers_override=num_hidden_layers,
+            require_post_norm=require_post_norm,
         )
     elif isinstance(vision_config, PixtralVisionConfig):
-        # TODO: allow layer override?
-        return PixtralHFVisionModel(vision_config)
+        return PixtralHFVisionModel(
+            vision_config,
+            quant_config,
+            num_hidden_layers_override=num_hidden_layers,
+            require_post_norm=require_post_norm,
+        )
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -255,7 +273,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             config.projector_hidden_act = "gelu"
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = _init_vision_tower(config)
+        self.vision_tower = init_vision_tower_for_llava(config, quant_config)
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
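
Note: the shared helper now accepts any config object that structurally matches `LlavaLikeConfig`, which is why `Protocol` is imported; the per-model copies of `_init_vision_tower` removed below become redundant. A small sketch of that idea (hypothetical config class, not part of the diff):

    from typing import Protocol

    class LlavaLikeConfig(Protocol):
        vision_config: object
        vision_feature_layer: int

    class MyLlavaStyleConfig:  # hypothetical; any class with these attributes matches
        vision_feature_layer = -2
        vision_config = object()

    def feature_layer_count(cfg: LlavaLikeConfig, num_hidden_layers: int) -> int:
        # Mirrors the helper: negative feature layers count back from the end.
        layer = cfg.vision_feature_layer
        return num_hidden_layers + layer + 1 if layer < 0 else layer + 1

    print(feature_layer_count(MyLlavaStyleConfig(), 24))  # -> 23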

View File

@@ -26,7 +26,7 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_clip_image_feature_size,
                    get_clip_patch_grid_length, input_processor_for_clip)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .llava import LlavaMultiModalProjector
+from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
@@ -259,32 +259,6 @@ def input_processor_for_llava_next(ctx: InputContext,
     raise NotImplementedError(msg)
-def _init_vision_tower(hf_config: LlavaNextConfig):
-    vision_config = hf_config.vision_config
-    # Initialize the vision tower only up to the required feature layer
-    vision_feature_layer = hf_config.vision_feature_layer
-    if vision_feature_layer < 0:
-        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
-            + vision_feature_layer + 1
-    else:
-        num_hidden_layers = vision_feature_layer + 1
-    if isinstance(vision_config, CLIPVisionConfig):
-        return CLIPVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-    elif isinstance(vision_config, SiglipVisionConfig):
-        return SiglipVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@@ -303,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.multimodal_config = multimodal_config
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = _init_vision_tower(config)
+        self.vision_tower = init_vision_tower_for_llava(config, quant_config)
         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
         self.multi_modal_projector = LlavaMultiModalProjector(

View File

@@ -26,6 +26,7 @@ from vllm.utils import is_list_of
 from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal, SupportsPP
+from .llava import init_vision_tower_for_llava
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
@@ -179,32 +180,6 @@ def input_processor_for_llava_next_video(ctx: InputContext,
     raise NotImplementedError(msg)
-def _init_vision_tower(hf_config: LlavaNextVideoConfig):
-    vision_config = hf_config.vision_config
-    # Initialize the vision tower only up to the required feature layer
-    vision_feature_layer = hf_config.vision_feature_layer
-    if vision_feature_layer < 0:
-        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
-            + vision_feature_layer + 1
-    else:
-        num_hidden_layers = vision_feature_layer + 1
-    if isinstance(vision_config, CLIPVisionConfig):
-        return CLIPVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-    elif isinstance(vision_config, SiglipVisionConfig):
-        return SiglipVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
 # adopted from transformers modeling_llava_next_video.py
 class LlavaNextVideoPooler(nn.Module):
@@ -281,7 +256,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.multimodal_config = multimodal_config
         # Initialize the vision tower only up to the required feature layer
-        self.vision_tower = _init_vision_tower(config)
+        self.vision_tower = init_vision_tower_for_llava(config, quant_config)
         self.vision_resampler = LlavaNextVideoPooler(config)
         self.multi_modal_projector = LlavaNextMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,

View File

@@ -31,6 +31,7 @@ from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
                    dummy_video_for_clip, get_clip_image_feature_size,
                    get_clip_patch_grid_length, input_processor_for_clip)
 from .interfaces import SupportsMultiModal, SupportsPP
+from .llava import init_vision_tower_for_llava
 from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
                      dummy_video_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
@@ -357,32 +358,6 @@ def input_processor_for_llava_onevision(ctx: InputContext,
     raise NotImplementedError(msg)
-def _init_vision_tower(hf_config: LlavaOnevisionConfig):
-    vision_config = hf_config.vision_config
-    # Initialize the vision tower only up to the required feature layer
-    vision_feature_layer = hf_config.vision_feature_layer
-    if vision_feature_layer < 0:
-        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
-            + vision_feature_layer + 1
-    else:
-        num_hidden_layers = vision_feature_layer + 1
-    if isinstance(vision_config, CLIPVisionConfig):
-        return CLIPVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-    elif isinstance(vision_config, SiglipVisionConfig):
-        return SiglipVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
 class LlavaOnevisionMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaOnevisionConfig):
@@ -425,7 +400,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.multimodal_config = multimodal_config
         # Initialize the vision tower only up to the required feature layer
-        self.vision_tower = _init_vision_tower(config)
+        self.vision_tower = init_vision_tower_for_llava(config, quant_config)
         self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
         self.language_model = init_vllm_registered_model(
             config.text_config, cache_config, quant_config)

View File

@@ -395,7 +395,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.version = get_version_by_config(self.config)
         self.llm = self.init_llm(config, cache_config, quant_config)
-        self.vpm = self.init_vision_module()
+        self.vpm = self.init_vision_module(config, quant_config)
         param_dtype = torch.get_default_dtype()
         self.vpm.to(dtype=param_dtype)
         self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
@@ -647,7 +647,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> nn.Module:
         raise NotImplementedError
-    def init_vision_module(self) -> nn.Module:
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+    ) -> nn.Module:
         raise NotImplementedError
     def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
@@ -693,7 +697,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
                                           quant_config=quant_config),
                            name="model")
-    def init_vision_module(self) -> nn.Module:
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+    ) -> nn.Module:
         # TODO :refactor this vision model
         try:
             import timm
@@ -817,8 +825,13 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
                                           quant_config=quant_config),
                            name="model")
-    def init_vision_module(self) -> nn.Module:
-        model = Idefics2VisionTransformer(self.config.vision_config)
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+    ) -> nn.Module:
+        model = Idefics2VisionTransformer(config.vision_config,
+                                          quant_config=quant_config)
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
         return model
@@ -929,9 +942,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
                                           quant_config=quant_config),
                            name="model")
-    def init_vision_module(self) -> nn.Module:
-        model = Idefics2VisionTransformer(self.config.vision_config)
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+    ) -> nn.Module:
+        model = Idefics2VisionTransformer(config.vision_config,
+                                          quant_config=quant_config)
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
         return model

View File

@ -379,9 +379,13 @@ class MllamaVisionSdpaAttention(nn.Module):
class MllamaVisionEncoderLayer(nn.Module): class MllamaVisionEncoderLayer(nn.Module):
def __init__(self, def __init__(
config: config_mllama.MllamaVisionConfig, self,
is_gated: bool = False): config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
is_gated: bool = False,
) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -390,7 +394,9 @@ class MllamaVisionEncoderLayer(nn.Module):
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.self_attn = MllamaVisionSdpaAttention(config) self.self_attn = MllamaVisionSdpaAttention(config)
self.mlp = CLIPMLP(config) self.mlp = CLIPMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = nn.LayerNorm(self.hidden_size, self.input_layernorm = nn.LayerNorm(self.hidden_size,
eps=config.norm_eps) eps=config.norm_eps)
@ -427,16 +433,23 @@ class MllamaVisionEncoderLayer(nn.Module):
class MllamaVisionEncoder(nn.Module): class MllamaVisionEncoder(nn.Module):
def __init__(self, def __init__(
config: config_mllama.MllamaVisionConfig, self,
num_layers=32, config: config_mllama.MllamaVisionConfig,
is_gated=False, quant_config: Optional[QuantizationConfig],
output_hidden_states=None): num_layers: int = 32,
is_gated: bool = False,
output_hidden_states=None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MllamaVisionEncoderLayer(config, is_gated) MllamaVisionEncoderLayer(config,
for _ in range(num_layers) quant_config=quant_config,
is_gated=is_gated,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_layers)
]) ])
self.output_hidden_states = output_hidden_states or [] self.output_hidden_states = output_hidden_states or []
@ -463,8 +476,14 @@ class MllamaVisionEncoder(nn.Module):
class MllamaVisionModel(nn.Module): class MllamaVisionModel(nn.Module):
def __init__(self, config: config_mllama.MllamaVisionConfig): def __init__(
self,
config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.image_size = config.image_size self.image_size = config.image_size
self.patch_size = config.patch_size self.patch_size = config.patch_size
self.max_num_tiles = config.max_num_tiles self.max_num_tiles = config.max_num_tiles
@ -500,12 +519,19 @@ class MllamaVisionModel(nn.Module):
# encoders # encoders
self.transformer = MllamaVisionEncoder( self.transformer = MllamaVisionEncoder(
config, config,
quant_config,
config.num_hidden_layers, config.num_hidden_layers,
is_gated=False, is_gated=False,
output_hidden_states=config.intermediate_layers_indices) output_hidden_states=config.intermediate_layers_indices,
self.global_transformer = MllamaVisionEncoder(config, prefix=f"{prefix}.transformer",
config.num_global_layers, )
is_gated=True) self.global_transformer = MllamaVisionEncoder(
config,
quant_config,
config.num_global_layers,
is_gated=True,
prefix=f"{prefix}.global_transformer",
)
def apply_class_embedding(self, def apply_class_embedding(self,
hidden_state: torch.Tensor) -> torch.Tensor: hidden_state: torch.Tensor) -> torch.Tensor:
@ -648,6 +674,7 @@ class MllamaTextCrossAttention(nn.Module):
config: Optional[config_mllama.MllamaTextConfig] = None, config: Optional[config_mllama.MllamaTextConfig] = None,
layer_idx: Optional[int] = None, layer_idx: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -673,6 +700,7 @@ class MllamaTextCrossAttention(nn.Module):
self.num_key_value_heads, self.num_key_value_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim, self.num_heads * self.head_dim,
@ -680,6 +708,7 @@ class MllamaTextCrossAttention(nn.Module):
bias=False, bias=False,
input_is_parallel=True, input_is_parallel=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj",
) )
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue, # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
# use huggingface's instead # use huggingface's instead
@ -692,6 +721,7 @@ class MllamaTextCrossAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
self.num_local_key_value_heads, self.num_local_key_value_heads,
prefix=f"{prefix}.attn",
) )
def forward( def forward(
@ -791,15 +821,21 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention """Cross-attention transformer block with tanh-gated attention
and feedforward.""" and feedforward."""
def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int, def __init__(
quant_config: Optional[QuantizationConfig]) \ self,
-> None: config: config_mllama.MllamaTextConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.cross_attn = MllamaTextCrossAttention( self.cross_attn = MllamaTextCrossAttention(
config=config, config=config,
layer_idx=layer_idx, layer_idx=layer_idx,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.cross_attn",
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
@ -811,6 +847,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp",
) )
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
@ -854,10 +891,15 @@ class MllamaTextModel(nn.Module):
config_class = config_mllama.MllamaTextConfig config_class = config_mllama.MllamaTextConfig
base_model_prefix = "model" base_model_prefix = "model"
def __init__(self, config: config_mllama.MllamaTextConfig, def __init__(
cache_config: Optional[CacheConfig], self,
quant_config: Optional[QuantizationConfig]): config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
@ -869,13 +911,20 @@ class MllamaTextModel(nn.Module):
if layer_idx in self.cross_attention_layers: if layer_idx in self.cross_attention_layers:
layers.append( layers.append(
MllamaCrossAttentionDecoderLayer( MllamaCrossAttentionDecoderLayer(
config, layer_idx, quant_config=quant_config)) config,
layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
))
else: else:
# TODO: force LlamaDecoderLayer to config.attention_bias=False # TODO: force LlamaDecoderLayer to config.attention_bias=False
layers.append( layers.append(
LlamaDecoderLayer(config, LlamaDecoderLayer(
cache_config=cache_config, config,
quant_config=quant_config)) cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
))
self.layers = nn.ModuleList(layers) self.layers = nn.ModuleList(layers)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -932,12 +981,19 @@ class MllamaForCausalLM(nn.Module):
"MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
] ]
def __init__(self, config: config_mllama.MllamaTextConfig, def __init__(
cache_config: Optional[CacheConfig], self,
quant_config: Optional[QuantizationConfig]): config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.model = MllamaTextModel(config, cache_config, quant_config) self.model = MllamaTextModel(config,
cache_config,
quant_config,
prefix=f"{prefix}.model")
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
@ -994,11 +1050,13 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
config.pad_token_id if config.pad_token_id is not None else -1 config.pad_token_id if config.pad_token_id is not None else -1
self.image_size = config.vision_config.image_size self.image_size = config.vision_config.image_size
self.vision_model = MllamaVisionModel(config.vision_config) self.vision_model = MllamaVisionModel(config.vision_config,
quant_config)
self.language_model = MllamaForCausalLM( self.language_model = MllamaForCausalLM(
config.text_config, config.text_config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix="language_model",
) )
self.multi_modal_projector = nn.Linear( self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim, config.vision_config.vision_output_dim,
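Not part of the diff: a minimal sketch of how the prefixes threaded through the Mllama constructors above compose into one fully qualified module name, which is what quantization methods use to identify individual submodules. The layer index is an assumption for illustration.

# Illustrative only; layer index 3 is an assumption.
prefix = "language_model"        # set in MllamaForConditionalGeneration
prefix = f"{prefix}.model"       # MllamaForCausalLM -> MllamaTextModel
prefix = f"{prefix}.layers.3"    # MllamaTextModel -> decoder layer 3
prefix = f"{prefix}.cross_attn"  # MllamaCrossAttentionDecoderLayer
print(f"{prefix}.o_proj")        # -> "language_model.model.layers.3.cross_attn.o_proj"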
@ -4,10 +4,13 @@
# Copyright (c) 2024 NVIDIA # Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details] # Licensed under Apache 2.0 License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from typing import Optional
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from .intern_vit import InternVisionModel from .intern_vit import InternVisionModel
@ -56,9 +59,11 @@ class NVLM_D_Model(InternVLChatModel):
) )
def _init_vision_model(self, config: PretrainedConfig, def _init_vision_model(self, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
num_hidden_layers: int): num_hidden_layers: int):
# We added additional dummy heads to the original num of heads to make # We added additional dummy heads to the original num of heads to make
# the number of heads divisible by 8. # the number of heads divisible by 8.
return InternVisionModel(config.vision_config, return InternVisionModel(config.vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
num_dummy_heads=7) num_dummy_heads=7)
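Not part of the diff: the arithmetic behind num_dummy_heads=7 above. The original head count is an assumption about the underlying InternViT encoder; only the divisibility argument is illustrated.

original_heads = 25           # assumed head count of the vision encoder
num_dummy_heads = 7           # value passed above
padded_heads = original_heads + num_dummy_heads
assert padded_heads % 8 == 0  # heads now split evenly across 8 tensor-parallel ranks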
@ -142,7 +142,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.vision_tower = SiglipVisionModel(config.vision_config) self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector( self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim) projection_dim=config.vision_config.projection_dim)
@ -70,7 +70,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
projection_dim=768) projection_dim=768)
def _init_img_processor(hf_config: PretrainedConfig): def _init_img_processor(hf_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig]):
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
layer_idx = hf_config.img_processor.get('layer_idx', -2) layer_idx = hf_config.img_processor.get('layer_idx', -2)
@ -82,7 +83,10 @@ def _init_img_processor(hf_config: PretrainedConfig):
num_hidden_layers = layer_idx + 1 num_hidden_layers = layer_idx + 1
img_processor = CLIPVisionModel( img_processor = CLIPVisionModel(
clip_config, num_hidden_layers_override=num_hidden_layers) clip_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
)
return img_processor return img_processor
@ -148,14 +152,15 @@ class Phi3ImageEmbeddingBase(nn.Module):
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform.""" """Phi3 Image embedding with HD transform."""
def __init__(self, config: PretrainedConfig) -> None: def __init__(self, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig]) -> None:
super().__init__() super().__init__()
# n_embed or hidden_size # n_embed or hidden_size
hidden_size = config.n_embd if hasattr( hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size config, 'n_embd') else config.hidden_size
self.img_processor = _init_img_processor(config) self.img_processor = _init_img_processor(config, quant_config)
image_dim_out = config.img_processor['image_dim_out'] image_dim_out = config.img_processor['image_dim_out']
self.num_img_tokens = config.img_processor['num_img_tokens'] self.num_img_tokens = config.img_processor['num_img_tokens']
@ -535,7 +540,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) )
# TODO: Optionally initializes this for supporting input embeddings. # TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(config) self.vision_embed_tokens = Phi3HDImageEmbedding(config, quant_config)
self.language_model = LlamaForCausalLM(config, cache_config, self.language_model = LlamaForCausalLM(config, cache_config,
quant_config) quant_config)
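Not part of the diff: a sketch of how layer_idx from hf_config.img_processor maps to the truncated CLIP depth used in _init_img_processor above. The negative-index branch is an assumption based on the surrounding code; the helper name is local to this example.

def resolve_num_hidden_layers(total_layers: int, layer_idx: int) -> int:
    # Negative indices count from the end, e.g. -2 keeps all but the last layer.
    if layer_idx < 0:
        return total_layers + layer_idx + 1
    return layer_idx + 1

assert resolve_num_hidden_layers(24, -2) == 23  # default feature layer
assert resolve_num_hidden_layers(24, 5) == 6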
@ -767,9 +767,17 @@ def input_processor_for_pixtral_hf(
class PixtralHFMLP(nn.Module): class PixtralHFMLP(nn.Module):
def __init__(self, config: PixtralVisionConfig): def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
assert config.intermediate_size is not None assert config.intermediate_size is not None
# TODO: Use quant_config and prefix after optimizing this
self.gate_proj = nn.Linear(config.hidden_size, self.gate_proj = nn.Linear(config.hidden_size,
config.intermediate_size, config.intermediate_size,
bias=False) bias=False)
@ -787,8 +795,15 @@ class PixtralHFMLP(nn.Module):
class PixtralHFAttention(nn.Module): class PixtralHFAttention(nn.Module):
def __init__(self, config: PixtralVisionConfig): def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
assert not config.hidden_size % config.num_attention_heads assert not config.hidden_size % config.num_attention_heads
self.n_heads = config.num_attention_heads self.n_heads = config.num_attention_heads
@ -796,6 +811,7 @@ class PixtralHFAttention(nn.Module):
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
# TODO: Use quant_config and prefix after optimizing this
self.q_proj = nn.Linear(config.hidden_size, self.q_proj = nn.Linear(config.hidden_size,
config.hidden_size, config.hidden_size,
bias=False) bias=False)
@ -840,11 +856,22 @@ class PixtralHFAttention(nn.Module):
class PixtralHFTransformerBlock(nn.Module): class PixtralHFTransformerBlock(nn.Module):
def __init__(self, config: PixtralVisionConfig): def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
self.attention = PixtralHFAttention(config) self.attention = PixtralHFAttention(config,
self.feed_forward = PixtralHFMLP(config) quant_config=quant_config,
prefix=f"{prefix}.attention")
self.feed_forward = PixtralHFMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward")
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
def forward( def forward(
@ -864,11 +891,27 @@ class PixtralHFTransformerBlock(nn.Module):
class PixtralHFTransformer(nn.Module): class PixtralHFTransformer(nn.Module):
def __init__(self, config: PixtralVisionConfig): def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(config.num_hidden_layers): if num_hidden_layers_override is None:
self.layers.append(PixtralHFTransformerBlock(config)) num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
PixtralHFTransformerBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward( def forward(
self, self,
@ -883,7 +926,15 @@ class PixtralHFTransformer(nn.Module):
class PixtralHFVisionModel(nn.Module): class PixtralHFVisionModel(nn.Module):
def __init__(self, config: PixtralVisionConfig): def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
@ -895,7 +946,24 @@ class PixtralHFVisionModel(nn.Module):
bias=False, bias=False,
) )
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralHFTransformer(config) self.transformer = PixtralHFTransformer(
config,
quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.transformer",
)
num_hidden_layers = config.num_hidden_layers
if len(self.transformer.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.transformer.layers)} "
"layers.")
if require_post_norm is True:
msg = "PixtralHFVisionModel does not have post-layernorm"
raise ValueError(msg)
self.dtype = next(self.parameters()).dtype self.dtype = next(self.parameters()).dtype
self.device = next(self.parameters()).device self.device = next(self.parameters()).device
self.patch_positional_embedding = PixtralRotaryEmbedding( self.patch_positional_embedding = PixtralRotaryEmbedding(
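Not part of the diff: a standalone sketch of the two checks added to PixtralHFVisionModel above. The function and the layer counts are illustrative, not vLLM APIs.

from typing import Optional

def check_pixtral_overrides(total_layers: int, built_layers: int,
                            require_post_norm: Optional[bool]) -> None:
    if built_layers > total_layers:
        raise ValueError(
            f"The original encoder only has {total_layers} layers, "
            f"but you requested {built_layers} layers.")
    if require_post_norm is True:
        # Pixtral's vision encoder has no post-layernorm to require.
        raise ValueError("PixtralHFVisionModel does not have post-layernorm")

check_pixtral_overrides(24, 20, None)   # truncated encoder: fine
check_pixtral_overrides(24, 24, False)  # explicit opt-out: fine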
@ -248,8 +248,10 @@ class SiglipParallelAttention(nn.Module):
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
@ -266,12 +268,14 @@ class SiglipParallelAttention(nn.Module):
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.num_heads, total_num_heads=self.num_heads,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
input_size=self.embed_dim, input_size=self.embed_dim,
output_size=self.embed_dim, output_size=self.embed_dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.out_proj",
) )
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
@ -314,8 +318,10 @@ class SiglipMLP(nn.Module):
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
@ -326,11 +332,13 @@ class SiglipMLP(nn.Module):
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
quant_config=quant_config if quantizable else None, quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc1",
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config if quantizable else None, quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc2",
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -346,15 +354,20 @@ class SiglipEncoderLayer(nn.Module):
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
num_heads = config.num_attention_heads num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0: if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = SiglipParallelAttention(config, self.self_attn = SiglipParallelAttention(
quant_config=quant_config) config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else: else:
self.self_attn = SiglipSdpaAttention(config) self.self_attn = SiglipSdpaAttention(config)
@ -363,6 +376,7 @@ class SiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp",
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
@ -392,8 +406,10 @@ class SiglipEncoder(nn.Module):
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None, num_hidden_layers_override: Optional[int] = None,
): prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
if num_hidden_layers_override is None: if num_hidden_layers_override is None:
@ -402,8 +418,10 @@ class SiglipEncoder(nn.Module):
num_hidden_layers = num_hidden_layers_override num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
SiglipEncoderLayer(config, quant_config=quant_config) SiglipEncoderLayer(config,
for _ in range(num_hidden_layers) quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
]) ])
def forward( def forward(
@ -424,7 +442,8 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
@ -433,7 +452,9 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
config.hidden_size, config.num_attention_heads, batch_first=True) config.hidden_size, config.num_attention_heads, batch_first=True)
self.layernorm = nn.LayerNorm(config.hidden_size, self.layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config=config, quant_config=quant_config) self.mlp = SiglipMLP(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size = hidden_state.shape[0] batch_size = hidden_state.shape[0]
@ -454,9 +475,13 @@ class SiglipVisionTransformer(nn.Module):
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None, num_hidden_layers_override: Optional[int] = None,
): require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
embed_dim = config.hidden_size embed_dim = config.hidden_size
@ -465,26 +490,34 @@ class SiglipVisionTransformer(nn.Module):
config, config,
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
) )
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers: if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError( raise ValueError(
f"The original encoder only has {config.num_hidden_layers} " f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers." f"layers, but you requested {len(self.encoder.layers)} layers."
) )
elif len(self.encoder.layers) == config.num_hidden_layers:
# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(embed_dim, self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
else: else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None self.post_layernorm = None
self.use_head = (True if not hasattr(config, "vision_use_head") else self.use_head = (True if not hasattr(config, "vision_use_head") else
config.vision_use_head) config.vision_use_head)
if self.use_head: if self.use_head:
self.head = SiglipMultiheadAttentionPoolingHead( self.head = SiglipMultiheadAttentionPoolingHead(
config=config, quant_config=quant_config) config=config,
quant_config=quant_config,
prefix=f"{prefix}.head",
)
def forward( def forward(
self, self,
@ -517,8 +550,11 @@ class SiglipVisionModel(nn.Module):
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None, num_hidden_layers_override: Optional[int] = None,
): require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
num_heads = config.num_attention_heads num_heads = config.num_attention_heads
@ -529,6 +565,8 @@ class SiglipVisionModel(nn.Module):
config, config,
quant_config, quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
) )
def get_input_embeddings(self) -> nn.Module: def get_input_embeddings(self) -> nn.Module:
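Not part of the diff: a standalone sketch of how require_post_norm now resolves in SiglipVisionTransformer above. The layer count and embedding dimension are illustrative; the helper is not a vLLM API.

from typing import Optional

import torch.nn as nn

def resolve_post_layernorm(total_layers: int, built_layers: int,
                           require_post_norm: Optional[bool],
                           embed_dim: int, eps: float) -> Optional[nn.LayerNorm]:
    if require_post_norm is None:
        # Default: keep post_layernorm only for a full-depth encoder; truncated
        # encoders return intermediate features and can skip it to save memory.
        require_post_norm = built_layers == total_layers
    return nn.LayerNorm(embed_dim, eps=eps) if require_post_norm else None

assert resolve_post_layernorm(27, 26, None, 1152, 1e-6) is None       # truncated -> skipped
assert resolve_post_layernorm(27, 27, None, 1152, 1e-6) is not None   # full depth -> kept
assert resolve_post_layernorm(27, 26, True, 1152, 1e-6) is not None   # forced on by caller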