[Model] support input image embedding for minicpmv (#9237)

Authored by whyiug on 2024-10-10 23:00:47 +08:00; committed by GitHub
parent 07c11cf4d4
commit 04de9057ab
3 changed files with 101 additions and 43 deletions


@@ -378,7 +378,7 @@ Text Generation
     - ✅︎
     - ✅︎
   * - :code:`MiniCPMV`
     - MiniCPM-V
-    - Image\ :sup:`+`
+    - Image\ :sup:`E+`
     - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     - ✅︎
     - ✅︎


@@ -57,12 +57,19 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
         print(generated_text)
 
     # Inference with image embeddings as input with additional parameters
-    # Specifically, we are conducting a trial run of Qwen2VL with the new input format, as the model utilizes additional parameters for calculating positional encoding.
-    image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
-    image_grid_thw = torch.load(...) # torch.Tensor of shape (1, 3)
+    # Specifically, we are conducting a trial run of Qwen2VL and MiniCPM-V with the new input format, which utilizes additional parameters.
+    mm_data = {}
+    image_embeds = torch.load(...) # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
+    # For Qwen2VL, image_grid_thw is needed to calculate positional encoding.
     mm_data['image'] = {
         "image_embeds": image_embeds,
-        "image_grid_thw": image_grid_thw,
+        "image_grid_thw": torch.load(...) # torch.Tensor of shape (1, 3),
+    }
+    # For MiniCPM-V, image_size_list is needed to calculate details of the sliced image.
+    mm_data['image'] = {
+        "image_embeds": image_embeds,
+        "image_size_list": [image.size] # list of image sizes
     }
 
     outputs = llm.generate({
         "prompt": prompt,


@@ -24,8 +24,8 @@
 import math
 import re
 from functools import partial
-from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
-                    TypedDict)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, TypedDict, Union)
 
 import torch
 import torch.types
@@ -65,10 +65,12 @@ _KEYS_TO_MODIFY_MAPPING = {
     "llm.lm_head": "lm_head",
 }
 
+RawImageType = Union[Image.Image, torch.Tensor]
 
-class MiniCPMVImageInput(TypedDict):
+
+class MiniCPMVRawImageInput(TypedDict):
     """Input mapper input with auxiliary data for computing image bounds."""
-    image: Image.Image
+    image: RawImageType
 
     # Image bounds token ids in 0-dim scaler tensor.
     im_start_id: torch.Tensor
@@ -78,7 +80,8 @@ class MiniCPMVImageInput(TypedDict):
 
 class MiniCPMVImagePixelInputs(TypedDict):
-    pixel_values: List[torch.Tensor]
+    type: Literal["pixel_values"]
+    data: List[torch.Tensor]
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
@@ -101,6 +104,27 @@ class MiniCPMVImagePixelInputs(TypedDict):
     """
 
 
+class MiniCPMVImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+    instead of a batched tensor.
+    """
+
+    image_bounds: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, 2)`
+    This should be in `(start, stop)` format.
+    """
+
+
+MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
+                            MiniCPMVImageEmbeddingInputs]
+
+
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
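
As a quick illustration (not from the diff), the `type` tag added to both TypedDicts lets downstream code narrow the `MiniCPMVImageInputs` union the same way `get_embedding` does further below. The trimmed stand-in classes, toy function, and tensor shapes here are made up for the example:

import torch
from typing import List, Literal, TypedDict, Union

class PixelInputs(TypedDict):      # trimmed stand-in for MiniCPMVImagePixelInputs
    type: Literal["pixel_values"]
    data: List[torch.Tensor]

class EmbeddingInputs(TypedDict):  # trimmed stand-in for MiniCPMVImageEmbeddingInputs
    type: Literal["image_embeds"]
    data: torch.Tensor

ImageInputs = Union[PixelInputs, EmbeddingInputs]

def describe(inputs: ImageInputs) -> str:
    # Branch on the literal tag instead of isinstance checks on the payload.
    if inputs["type"] == "image_embeds":
        return f"precomputed embeddings with shape {tuple(inputs['data'].shape)}"
    return f"{len(inputs['data'])} raw pixel tensors for the vision tower"

print(describe(EmbeddingInputs(type="image_embeds", data=torch.zeros(1, 64, 4096))))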
@@ -194,21 +218,21 @@ class Resampler2_5(BaseResampler):
 
 def _build_image_input(ctx: InputContext,
-                       image: Image.Image) -> MiniCPMVImageInput:
+                       image: RawImageType) -> MiniCPMVRawImageInput:
     tokenizer = cached_get_tokenizer(
         ctx.model_config.tokenizer,
         trust_remote_code=ctx.model_config.trust_remote_code)
     if hasattr(tokenizer, "slice_start_id"):
-        return MiniCPMVImageInput(
+        return MiniCPMVRawImageInput(
             image=image,
             im_start_id=torch.tensor(tokenizer.im_start_id),
             im_end_id=torch.tensor(tokenizer.im_end_id),
             slice_start_id=torch.tensor(tokenizer.slice_start_id),
             slice_end_id=torch.tensor(tokenizer.slice_end_id))
     else:
-        return MiniCPMVImageInput(image=image,
-                                  im_start_id=torch.tensor(
-                                      tokenizer.im_start_id),
-                                  im_end_id=torch.tensor(tokenizer.im_end_id))
+        return MiniCPMVRawImageInput(
+            image=image,
+            im_start_id=torch.tensor(tokenizer.im_start_id),
+            im_end_id=torch.tensor(tokenizer.im_end_id))
@@ -280,20 +304,25 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
     pattern = "(<image>./</image>)"
     images = multi_modal_data["image"]
-    if isinstance(images, Image.Image):
-        images = [images]
 
     image_tags = re.findall(pattern, prompt)
     if len(image_tags) == 0:
         new_token_ids = token_ids
         new_prompt = prompt
     else:
+        if isinstance(images, dict):
+            image_size_list = images.get("image_size_list")
+            images = [images.get("image_embeds")]
+        else:
+            if isinstance(images, Image.Image):
+                images = [images]
+            image_size_list = [image.size for image in images]
+
         text_chunks = prompt.split(pattern)
         new_prompt_chunks: List[str] = []
-        for i in range(len(images)):
+        for i in range(len(image_size_list)):
             new_prompt_chunks += [
                 text_chunks[i],
-                get_placeholder(images[i].size, i)
+                get_placeholder(image_size_list[i], i)
             ]
         new_prompt_chunks.append(text_chunks[-1])
         new_prompt = "".join(new_prompt_chunks)
@@ -323,6 +352,12 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
     if not isinstance(data, list):
         raise ValueError(
             "Image input must be list of MiniCPMVImageInput, got (%s)", data)
-    batch_data = image_processor \
-        .preprocess([img["image"] for img in data], return_tensors="pt") \
-        .data
+
+    if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor):
+        batch_data = {
+            "image_embeds": data[0]['image'],
+        }
+    else:
+        batch_data = image_processor \
+            .preprocess([img["image"] for img in data], return_tensors="pt") \
+            .data
@@ -380,7 +415,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     def get_embedding(
         self,
         input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImagePixelInputs],
+        image_inputs: Optional[MiniCPMVImageInputs],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
         vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
         if hasattr(self.config, "scale_emb"):
@@ -389,7 +424,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         if image_inputs is None:  # No image
             vision_hidden_states = torch.tensor([], device=input_ids.device)
         else:
-            vision_hidden_states = self.get_vision_hidden_states(image_inputs)
+            if image_inputs["type"] == "image_embeds":
+                vision_hidden_states = (image_inputs["data"].type(
+                    vlm_embedding.dtype).to(vlm_embedding.device))
+            else:
+                vision_hidden_states = self.get_vision_hidden_states(
+                    image_inputs)
 
             # See NOTE in _parse_and_validate_inputs
             image_bounds = image_inputs["image_bounds"]
@ -440,9 +480,23 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
**kwargs: object, **kwargs: object,
) -> Optional[MiniCPMVImagePixelInputs]: ) -> Optional[MiniCPMVImageInputs]:
pixel_values = kwargs.pop("pixel_values", []) pixel_values = kwargs.pop("pixel_values", [])
tgt_sizes = kwargs.pop("tgt_sizes", []) tgt_sizes = kwargs.pop("tgt_sizes", [])
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
image_embeds = kwargs.pop("image_embeds", None)
if image_embeds is not None:
return MiniCPMVImageEmbeddingInputs(
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=image_embeds,
type="image_embeds",
)
if not isinstance(pixel_values, (torch.Tensor, list)): if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
@@ -477,10 +531,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         if len(pixel_values_flat) == 0:
             return None
 
-        im_start_id = kwargs.pop("im_start_id", None)
-        im_end_id = kwargs.pop("im_end_id", None)
-        slice_start_id = kwargs.pop("slice_start_id", None)
-        slice_end_id = kwargs.pop("slice_end_id", None)
         if im_start_id is None:
             return None
@@ -488,8 +538,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
             image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                 im_end_id, slice_start_id,
                                                 slice_end_id),
-            pixel_values=pixel_values_flat,
+            data=pixel_values_flat,
             tgt_sizes=torch.stack(tgt_sizes_flat),
+            type="pixel_values",
         )
 
     def forward(
@@ -610,8 +661,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
         raise NotImplementedError
 
     def is_default_weight_loading(self, name: str) -> bool:
@@ -705,9 +756,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
             res.append(self.resampler(vision_embedding, tgt_size))
         return torch.vstack(res)
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
 
         return self.get_vision_embedding(pixel_values)
@@ -793,9 +844,9 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         vision_embedding = self.resampler(vision_embedding, tgt_sizes)
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device
@@ -909,9 +960,9 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         )
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device