[Model] SiglipVisionModel ported from transformers (#6942)
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
cc08fc7225
commit
c0d8f1636c
@ -65,7 +65,8 @@ def run_phi3v(question):
|
|||||||
# PaliGemma
|
# PaliGemma
|
||||||
def run_paligemma(question):
|
def run_paligemma(question):
|
||||||
|
|
||||||
prompt = question
|
# PaliGemma has special prompt format for VQA
|
||||||
|
prompt = "caption en"
|
||||||
llm = LLM(model="google/paligemma-3b-mix-224")
|
llm = LLM(model="google/paligemma-3b-mix-224")
|
||||||
|
|
||||||
return llm, prompt
|
return llm, prompt
|
||||||
|
|||||||
@ -1,9 +1,8 @@
|
|||||||
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
|
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel
|
from transformers import PaliGemmaConfig
|
||||||
|
|
||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import CacheConfig, MultiModalConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
@ -18,9 +17,11 @@ from vllm.model_executor.models.gemma import GemmaModel
|
|||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.image import cached_get_tokenizer
|
from vllm.multimodal.image import cached_get_tokenizer
|
||||||
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
|
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||||
|
|
||||||
from .interfaces import SupportsVision
|
from .interfaces import SupportsVision
|
||||||
|
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||||
|
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
|
||||||
from .utils import merge_vision_embeddings
|
from .utils import merge_vision_embeddings
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -32,55 +33,22 @@ _KEYS_TO_MODIFY_MAPPING = {
|
|||||||
|
|
||||||
def get_max_paligemma_image_tokens(ctx: InputContext):
|
def get_max_paligemma_image_tokens(ctx: InputContext):
|
||||||
hf_config = ctx.get_hf_config(PaliGemmaConfig)
|
hf_config = ctx.get_hf_config(PaliGemmaConfig)
|
||||||
text_config = hf_config.text_config
|
vision_config = hf_config.vision_config
|
||||||
|
|
||||||
return text_config.num_image_tokens
|
return get_max_siglip_image_tokens(vision_config)
|
||||||
|
|
||||||
|
|
||||||
def dummy_seq_data_for_paligemma(
|
|
||||||
hf_config: PaliGemmaConfig,
|
|
||||||
seq_len: int,
|
|
||||||
*,
|
|
||||||
image_token_id: int,
|
|
||||||
image_feature_size_override: Optional[int] = None,
|
|
||||||
):
|
|
||||||
if image_feature_size_override is None:
|
|
||||||
image_feature_size = hf_config.text_config.num_image_tokens
|
|
||||||
else:
|
|
||||||
image_feature_size = image_feature_size_override
|
|
||||||
|
|
||||||
token_ids = [image_token_id] * image_feature_size
|
|
||||||
token_ids += [0] * (seq_len - image_feature_size)
|
|
||||||
return SequenceData(token_ids)
|
|
||||||
|
|
||||||
|
|
||||||
def dummy_image_for_paligemma(
|
|
||||||
hf_config: SiglipVisionConfig,
|
|
||||||
*,
|
|
||||||
image_width_override: Optional[int] = None,
|
|
||||||
image_height_override: Optional[int] = None,
|
|
||||||
):
|
|
||||||
width = height = hf_config.image_size
|
|
||||||
if image_width_override is not None:
|
|
||||||
width = image_width_override
|
|
||||||
if image_height_override is not None:
|
|
||||||
height = image_height_override
|
|
||||||
|
|
||||||
image = Image.new("RGB", (width, height), color=0)
|
|
||||||
return {"image": image}
|
|
||||||
|
|
||||||
|
|
||||||
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
|
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
|
||||||
hf_config = ctx.get_hf_config(PaliGemmaConfig)
|
hf_config = ctx.get_hf_config(PaliGemmaConfig)
|
||||||
vision_config = hf_config.vision_config
|
vision_config = hf_config.vision_config
|
||||||
|
|
||||||
seq_data = dummy_seq_data_for_paligemma(
|
seq_data = dummy_seq_data_for_siglip(
|
||||||
hf_config,
|
vision_config,
|
||||||
seq_len,
|
seq_len,
|
||||||
image_token_id=hf_config.image_token_index,
|
image_token_id=hf_config.image_token_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
mm_data = dummy_image_for_paligemma(vision_config)
|
mm_data = dummy_image_for_siglip(vision_config)
|
||||||
return seq_data, mm_data
|
return seq_data, mm_data
|
||||||
|
|
||||||
|
|
||||||
@ -208,30 +176,37 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
|
|||||||
data=self._validate_pixel_values(pixel_values),
|
data=self._validate_pixel_values(pixel_values),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
|
def _image_pixels_to_features(
|
||||||
pixel_values: torch.Tensor) -> torch.Tensor:
|
self,
|
||||||
|
vision_tower: SiglipVisionModel,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
target_dtype = vision_tower.get_input_embeddings().weight.dtype
|
target_dtype = vision_tower.get_input_embeddings().weight.dtype
|
||||||
image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
|
image_features = vision_tower(pixel_values.to(dtype=target_dtype))
|
||||||
output_hidden_states=True)
|
|
||||||
|
|
||||||
selected_image_features = image_outputs.last_hidden_state
|
return image_features
|
||||||
|
|
||||||
return selected_image_features
|
|
||||||
|
|
||||||
def _process_image_pixels(
|
def _process_image_pixels(
|
||||||
self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor:
|
self,
|
||||||
|
inputs: PaliGemmaImagePixelInputs,
|
||||||
|
) -> torch.Tensor:
|
||||||
assert self.vision_tower is not None
|
assert self.vision_tower is not None
|
||||||
|
|
||||||
pixel_values = inputs["data"]
|
pixel_values = inputs["data"]
|
||||||
|
|
||||||
return self._image_pixels_to_features(self.vision_tower, pixel_values)
|
return self._image_pixels_to_features(
|
||||||
|
self.vision_tower,
|
||||||
|
pixel_values,
|
||||||
|
)
|
||||||
|
|
||||||
def _process_image_input(
|
def _process_image_input(
|
||||||
self, image_input: PaliGemmaImageInputs) -> torch.Tensor:
|
self,
|
||||||
|
image_input: PaliGemmaImageInputs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
assert self.vision_tower is not None
|
assert self.vision_tower is not None
|
||||||
image_features = self._process_image_pixels(image_input)
|
image_features = self._process_image_pixels(image_input, )
|
||||||
|
|
||||||
return self.multi_modal_projector(image_features)
|
return self.multi_modal_projector(image_features)
|
||||||
|
|
||||||
|
|||||||
621
vllm/model_executor/models/siglip.py
Normal file
621
vllm/model_executor/models/siglip.py
Normal file
@ -0,0 +1,621 @@
|
|||||||
|
"""Implementation of SiglipVisionModel intended to be only used
|
||||||
|
within a vision language model."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch import nn
|
||||||
|
from transformers import SiglipConfig, SiglipVisionConfig
|
||||||
|
from transformers.models.siglip.modeling_siglip import SiglipAttention
|
||||||
|
from vllm_flash_attn import flash_attn_func
|
||||||
|
from xformers.ops import memory_efficient_attention
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from vllm.inputs import LLMInputs
|
||||||
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
VocabParallelEmbedding)
|
||||||
|
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||||
|
repeat_and_pad_image_tokens)
|
||||||
|
from vllm.sequence import SequenceData
|
||||||
|
|
||||||
|
|
||||||
|
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||||
|
assert image_size % patch_size == 0
|
||||||
|
return image_size // patch_size
|
||||||
|
|
||||||
|
|
||||||
|
def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
|
||||||
|
grid_length = get_siglip_patch_grid_length(image_size=image_size,
|
||||||
|
patch_size=patch_size)
|
||||||
|
return grid_length * grid_length
|
||||||
|
|
||||||
|
|
||||||
|
def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
|
||||||
|
return get_siglip_num_patches(image_size=hf_config.image_size,
|
||||||
|
patch_size=hf_config.patch_size)
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
|
||||||
|
return get_siglip_image_feature_size(hf_config)
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_seq_data_for_siglip(
|
||||||
|
hf_config: SiglipVisionConfig,
|
||||||
|
seq_len: int,
|
||||||
|
*,
|
||||||
|
image_token_id: int,
|
||||||
|
image_feature_size_override: Optional[int] = None,
|
||||||
|
):
|
||||||
|
if image_feature_size_override is None:
|
||||||
|
image_feature_size = get_siglip_image_feature_size(hf_config)
|
||||||
|
else:
|
||||||
|
image_feature_size = image_feature_size_override
|
||||||
|
|
||||||
|
token_ids = [image_token_id] * image_feature_size
|
||||||
|
token_ids += [0] * (seq_len - image_feature_size)
|
||||||
|
return SequenceData(token_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_image_for_siglip(
|
||||||
|
hf_config: SiglipVisionConfig,
|
||||||
|
*,
|
||||||
|
image_width_override: Optional[int] = None,
|
||||||
|
image_height_override: Optional[int] = None,
|
||||||
|
):
|
||||||
|
width = height = hf_config.image_size
|
||||||
|
if image_width_override is not None:
|
||||||
|
width = image_width_override
|
||||||
|
if image_height_override is not None:
|
||||||
|
height = image_height_override
|
||||||
|
|
||||||
|
image = Image.new("RGB", (width, height), color=0)
|
||||||
|
return {"image": image}
|
||||||
|
|
||||||
|
|
||||||
|
def input_processor_for_siglip(
|
||||||
|
model_config: ModelConfig,
|
||||||
|
hf_config: SiglipVisionConfig,
|
||||||
|
llm_inputs: LLMInputs,
|
||||||
|
*,
|
||||||
|
image_token_id: int,
|
||||||
|
image_feature_size_override: Optional[int] = None,
|
||||||
|
):
|
||||||
|
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||||
|
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||||
|
return llm_inputs
|
||||||
|
|
||||||
|
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||||
|
|
||||||
|
if image_feature_size_override is None:
|
||||||
|
image_feature_size = get_siglip_image_feature_size(hf_config)
|
||||||
|
else:
|
||||||
|
image_feature_size = image_feature_size_override
|
||||||
|
|
||||||
|
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
|
||||||
|
tokenizer,
|
||||||
|
llm_inputs.get("prompt"),
|
||||||
|
llm_inputs["prompt_token_ids"],
|
||||||
|
image_token_id=image_token_id,
|
||||||
|
repeat_count=image_feature_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE: Create a defensive copy of the original inputs
|
||||||
|
return LLMInputs(
|
||||||
|
prompt_token_ids=new_token_ids,
|
||||||
|
prompt=new_prompt,
|
||||||
|
multi_modal_data=multi_modal_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
|
||||||
|
class SiglipVisionEmbeddings(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: SiglipVisionConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.image_size = config.image_size
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
|
||||||
|
self.patch_embedding = nn.Conv2d(
|
||||||
|
in_channels=config.num_channels,
|
||||||
|
out_channels=self.embed_dim,
|
||||||
|
kernel_size=self.patch_size,
|
||||||
|
stride=self.patch_size,
|
||||||
|
padding="valid",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_patches = (self.image_size // self.patch_size)**2
|
||||||
|
self.num_positions = self.num_patches
|
||||||
|
self.position_embedding = VocabParallelEmbedding(
|
||||||
|
self.num_positions, self.embed_dim)
|
||||||
|
self.register_buffer(
|
||||||
|
"position_ids",
|
||||||
|
torch.arange(self.num_positions, dtype=torch.int64).expand(
|
||||||
|
(1, -1)),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
|
||||||
|
width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method is an adapted method for SigLIP (due to SigLIP not having
|
||||||
|
class embedding unlike other ViTs) that allows the model to interpolate
|
||||||
|
the pre-trained position encodings such that it can be usable on higher
|
||||||
|
resolution images.
|
||||||
|
|
||||||
|
Source:
|
||||||
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||||
|
"""
|
||||||
|
position_embeddings = self.position_embedding.weight.unsqueeze(0)
|
||||||
|
num_patches = embeddings.shape[1]
|
||||||
|
num_positions = position_embeddings.shape[1]
|
||||||
|
if num_patches == num_positions and height == width:
|
||||||
|
return position_embeddings
|
||||||
|
|
||||||
|
dim = embeddings.shape[-1]
|
||||||
|
height = height // self.patch_size
|
||||||
|
width = width // self.patch_size
|
||||||
|
# we add a small number to avoid floating point error
|
||||||
|
# in the interpolation
|
||||||
|
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||||
|
height, width = height + 0.1, width + 0.1
|
||||||
|
|
||||||
|
patch_pos_embed = position_embeddings.reshape(
|
||||||
|
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
|
||||||
|
dim)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||||
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
|
patch_pos_embed,
|
||||||
|
scale_factor=(
|
||||||
|
height / math.sqrt(num_positions),
|
||||||
|
width / math.sqrt(num_positions),
|
||||||
|
),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
if (int(height) != patch_pos_embed.shape[-2]
|
||||||
|
or int(width) != patch_pos_embed.shape[-1]):
|
||||||
|
raise ValueError("Width or height does not match with "
|
||||||
|
"the interpolated position embeddings")
|
||||||
|
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||||
|
return patch_pos_embed
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
||||||
|
_, _, height, width = pixel_values.shape
|
||||||
|
target_dtype = self.patch_embedding.weight.dtype
|
||||||
|
patch_embeds = self.patch_embedding(pixel_values.to(
|
||||||
|
dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||||
|
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
embeddings = embeddings + self.interpolate_pos_encoding(
|
||||||
|
embeddings, height, width)
|
||||||
|
else:
|
||||||
|
embeddings = embeddings + self.position_embedding(
|
||||||
|
self.position_ids)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: Not used - kept for later when we TP the ViT
|
||||||
|
# TODO(ChristopherCho): Implement TP version of Attention
|
||||||
|
class SiglipTPAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = config.num_attention_heads
|
||||||
|
if self.total_num_heads % tp_size != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of attention heads ({self.total_num_heads}) "
|
||||||
|
"must be divisible by the tensor model parallel size"
|
||||||
|
f" ({tp_size}).")
|
||||||
|
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.head_dim = self.embed_dim // self.total_num_heads
|
||||||
|
if self.head_dim * self.total_num_heads != self.embed_dim:
|
||||||
|
raise ValueError(f"embed_dim must be divisible by num_heads (got "
|
||||||
|
"`embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||||
|
f" {self.num_heads}).")
|
||||||
|
self.qkv_size = self.num_heads * self.head_dim
|
||||||
|
self.scale = self.head_dim**-0.5
|
||||||
|
self.dropout = config.attention_dropout
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size=self.embed_dim,
|
||||||
|
head_size=self.head_dim,
|
||||||
|
total_num_heads=self.total_num_heads,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.out_proj = RowParallelLinear(
|
||||||
|
input_size=self.embed_dim,
|
||||||
|
output_size=self.embed_dim,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attn_fn = self._basic_attention_forward
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
batch_size, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||||
|
query_states, key_states, value_states = qkv_states.split(
|
||||||
|
[self.qkv_size] * 3, dim=-1)
|
||||||
|
|
||||||
|
attn_output = self.attn_fn(
|
||||||
|
q=query_states,
|
||||||
|
k=key_states,
|
||||||
|
v=value_states,
|
||||||
|
batch_size=batch_size,
|
||||||
|
q_len=q_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output, _ = self.out_proj(attn_output)
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
def _basic_attention_forward(self, q, k, v, batch_size, q_len):
|
||||||
|
q = q.view(batch_size, q_len, self.num_heads,
|
||||||
|
self.head_dim).transpose(1, 2)
|
||||||
|
k = k.view(batch_size, q_len, self.num_heads,
|
||||||
|
self.head_dim).transpose(1, 2)
|
||||||
|
v = v.view(batch_size, q_len, self.num_heads,
|
||||||
|
self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
k_v_seq_len = k.shape[-2]
|
||||||
|
attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
|
||||||
|
|
||||||
|
if attn_weights.size() != (
|
||||||
|
batch_size,
|
||||||
|
self.num_heads,
|
||||||
|
q_len,
|
||||||
|
k_v_seq_len,
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Attention weights should be of size "
|
||||||
|
f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}")
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights,
|
||||||
|
dim=-1,
|
||||||
|
dtype=torch.float32).to(q.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights,
|
||||||
|
p=self.dropout,
|
||||||
|
training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, v)
|
||||||
|
|
||||||
|
if attn_output.size() != (
|
||||||
|
batch_size,
|
||||||
|
self.num_heads,
|
||||||
|
q_len,
|
||||||
|
self.head_dim,
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"`attn_output` should be of size "
|
||||||
|
f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}")
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: Not used - kept for later when we TP the ViT
|
||||||
|
# TODO(ChristopherCho): flash_attn_func is not working properly.
|
||||||
|
# It constantly throws a CUDA error.
|
||||||
|
class SiglipFlashAttention2(SiglipTPAttention):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.attn_fn = self._flash_attention_forward
|
||||||
|
|
||||||
|
# Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
|
||||||
|
# and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
|
||||||
|
def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
|
||||||
|
**kwargs):
|
||||||
|
"""Implements the multihead softmax attention.
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
q, k, v: The tensor containing the
|
||||||
|
query, key, and value. (B, S, H, D)
|
||||||
|
"""
|
||||||
|
|
||||||
|
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||||
|
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||||
|
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
attn_output = flash_attn_func(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
dropout_p=self.dropout,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(batch_size, q_len,
|
||||||
|
self.embed_dim).contiguous()
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: Not used - kept for later when we TP the ViT
|
||||||
|
class SiglipSdpaAttention(SiglipTPAttention):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.is_causal = False
|
||||||
|
self.attn_fn = self._sdpa_attention_forward
|
||||||
|
|
||||||
|
def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
|
||||||
|
q = q.view(batch_size, q_len, self.num_heads,
|
||||||
|
self.head_dim).transpose(1, 2)
|
||||||
|
k = k.view(batch_size, q_len, self.num_heads,
|
||||||
|
self.head_dim).transpose(1, 2)
|
||||||
|
v = v.view(batch_size, q_len, self.num_heads,
|
||||||
|
self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: Not used - kept for later when we TP the ViT
|
||||||
|
class SiglipxFormersAttention(SiglipTPAttention):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.attn_fn = self._xformers_attention_forward
|
||||||
|
|
||||||
|
def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
|
||||||
|
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||||
|
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||||
|
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
attn_output = memory_efficient_attention(q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
p=0.0,
|
||||||
|
scale=self.scale)
|
||||||
|
attn_output = attn_output.reshape(batch_size, q_len,
|
||||||
|
self.embed_dim).contiguous()
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: Not used - kept for later when we TP the ViT
|
||||||
|
SIGLIP_ATTENTION_CLASSES = {
|
||||||
|
"eager": SiglipTPAttention,
|
||||||
|
"flash_attention_2": SiglipFlashAttention2,
|
||||||
|
"sdpa": SiglipSdpaAttention,
|
||||||
|
"xformers": SiglipxFormersAttention,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SiglipMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.activation_fn = get_act_fn(config.hidden_act)
|
||||||
|
|
||||||
|
# For quantization, we require the hidden size to be a multiple of 64
|
||||||
|
quantizable = (config.hidden_size % 64 == 0
|
||||||
|
and config.intermediate_size % 64 == 0)
|
||||||
|
self.fc1 = ColumnParallelLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
quant_config=quant_config if quantizable else None,
|
||||||
|
)
|
||||||
|
self.fc2 = RowParallelLinear(
|
||||||
|
config.intermediate_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config if quantizable else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states, _ = self.fc1(hidden_states)
|
||||||
|
hidden_states = self.activation_fn(hidden_states)
|
||||||
|
hidden_states, _ = self.fc2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class SiglipEncoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: SiglipConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
|
||||||
|
# TODO(ChristopherCho): use TP'ed Attention block
|
||||||
|
self.self_attn = SiglipAttention(config)
|
||||||
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
self.mlp = SiglipMLP(
|
||||||
|
config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor]:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
|
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states, None
|
||||||
|
|
||||||
|
|
||||||
|
class SiglipEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: SiglipConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
SiglipEncoderLayer(
|
||||||
|
config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
) for _ in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs_embeds: torch.Tensor,
|
||||||
|
) -> Tuple:
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
for encoder_layer in self.layers:
|
||||||
|
hidden_states, _ = encoder_layer(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||||
|
"""Multihead Attention Pooling."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: SiglipVisionConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
||||||
|
# TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
|
||||||
|
self.attention = torch.nn.MultiheadAttention(
|
||||||
|
config.hidden_size, config.num_attention_heads, batch_first=True)
|
||||||
|
self.layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
self.mlp = SiglipMLP(config=config, quant_config=quant_config)
|
||||||
|
|
||||||
|
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size = hidden_state.shape[0]
|
||||||
|
probe = self.probe.repeat(batch_size, 1, 1)
|
||||||
|
|
||||||
|
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
|
||||||
|
|
||||||
|
residual = hidden_state
|
||||||
|
hidden_state = self.layernorm(hidden_state)
|
||||||
|
hidden_state = residual + self.mlp(hidden_state)
|
||||||
|
|
||||||
|
return hidden_state[:, 0]
|
||||||
|
|
||||||
|
|
||||||
|
class SiglipVisionTransformer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: SiglipVisionConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
embed_dim = config.hidden_size
|
||||||
|
|
||||||
|
self.embeddings = SiglipVisionEmbeddings(config)
|
||||||
|
self.encoder = SiglipEncoder(
|
||||||
|
config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.post_layernorm = nn.LayerNorm(embed_dim,
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
self.use_head = (True if not hasattr(config, "vision_use_head") else
|
||||||
|
config.vision_use_head)
|
||||||
|
if self.use_head:
|
||||||
|
self.head = SiglipMultiheadAttentionPoolingHead(
|
||||||
|
config=config, quant_config=quant_config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
interpolate_pos_encoding: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embeddings(
|
||||||
|
pixel_values,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
|
||||||
|
|
||||||
|
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||||
|
|
||||||
|
# TODO: add this back when pooled_output is used in inference
|
||||||
|
# if self.use_head:
|
||||||
|
# pooled_output = self.head(last_hidden_state)
|
||||||
|
|
||||||
|
return last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class SiglipVisionModel(nn.Module):
|
||||||
|
config_class = SiglipVisionConfig
|
||||||
|
main_input_name = "pixel_values"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: SiglipVisionConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.vision_model = SiglipVisionTransformer(
|
||||||
|
config,
|
||||||
|
quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Module:
|
||||||
|
return self.vision_model.embeddings.patch_embedding
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
|
)
|
||||||
Loading…
Reference in New Issue
Block a user