[Model] support input image embedding for minicpmv (#9237)
commit 04de9057ab (parent 07c11cf4d4)
@@ -378,7 +378,7 @@ Text Generation
     - ✅︎
   * - :code:`MiniCPMV`
     - MiniCPM-V
-    - Image\ :sup:`+`
+    - Image\ :sup:`E+`
     - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     - ✅︎
     - ✅︎
@@ -57,12 +57,19 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
     print(generated_text)
 
 # Inference with image embeddings as input with additional parameters
-# Specifically, we are conducting a trial run of Qwen2VL with the new input format, as the model utilizes additional parameters for calculating positional encoding.
-image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
-image_grid_thw = torch.load(...) # torch.Tensor of shape (1, 3)
+# Specifically, we are conducting a trial run of Qwen2VL and MiniCPM-V with the new input format, which utilizes additional parameters.
+mm_data = {}
+
+image_embeds = torch.load(...) # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
+# For Qwen2VL, image_grid_thw is needed to calculate positional encoding.
 mm_data['image'] = {
     "image_embeds": image_embeds,
-    "image_grid_thw": image_grid_thw,
+    "image_grid_thw": torch.load(...) # torch.Tensor of shape (1, 3),
+}
+# For MiniCPM-V, image_size_list is needed to calculate details of the sliced image.
+mm_data['image'] = {
+    "image_embeds": image_embeds,
+    "image_size_list": [image.size] # list of image sizes
 }
 outputs = llm.generate({
     "prompt": prompt,
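Taken together, the documented MiniCPM-V flow looks roughly like the following sketch. The model name, prompt template, file paths and sampling settings are illustrative assumptions; only the multi_modal_data keys (image_embeds, image_size_list) come from this change.

import torch
from PIL import Image

from vllm import LLM, SamplingParams

# Assumed model and prompt; in practice the prompt usually comes from the chat template.
llm = LLM(model="openbmb/MiniCPM-V-2_6", trust_remote_code=True)
prompt = "(<image>./</image>)\nWhat is in this image?"

image = Image.open("example.jpg")  # placeholder path
# Precomputed embeddings, shape (num_images, image_feature_size, hidden_size of the LM).
image_embeds = torch.load("image_embeds.pt")  # placeholder path

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            "image": {
                "image_embeds": image_embeds,
                # image_size_list tells the processor how each original image was sized.
                "image_size_list": [image.size],
            }
        },
    },
    SamplingParams(max_tokens=64))

for o in outputs:
    print(o.outputs[0].text)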
@@ -24,8 +24,8 @@
 import math
 import re
 from functools import partial
-from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
-                    TypedDict)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, TypedDict, Union)
 
 import torch
 import torch.types
@@ -65,10 +65,12 @@ _KEYS_TO_MODIFY_MAPPING = {
     "llm.lm_head": "lm_head",
 }
 
+RawImageType = Union[Image.Image, torch.Tensor]
+
 
-class MiniCPMVImageInput(TypedDict):
+class MiniCPMVRawImageInput(TypedDict):
     """Input mapper input with auxiliary data for computing image bounds."""
-    image: Image.Image
+    image: RawImageType
 
     # Image bounds token ids in 0-dim scaler tensor.
     im_start_id: torch.Tensor
@@ -78,7 +80,8 @@ class MiniCPMVImageInput(TypedDict):
 
 
 class MiniCPMVImagePixelInputs(TypedDict):
-    pixel_values: List[torch.Tensor]
+    type: Literal["pixel_values"]
+    data: List[torch.Tensor]
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
 
@@ -101,6 +104,27 @@ class MiniCPMVImagePixelInputs(TypedDict):
     """
 
 
+class MiniCPMVImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+    instead of a batched tensor.
+    """
+
+    image_bounds: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, 2)`
+
+    This should be in `(start, stop)` format.
+    """
+
+
+MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
+                            MiniCPMVImageEmbeddingInputs]
+
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
 
@@ -194,22 +218,22 @@ class Resampler2_5(BaseResampler):
 
 
 def _build_image_input(ctx: InputContext,
-                       image: Image.Image) -> MiniCPMVImageInput:
+                       image: RawImageType) -> MiniCPMVRawImageInput:
     tokenizer = cached_get_tokenizer(
         ctx.model_config.tokenizer,
         trust_remote_code=ctx.model_config.trust_remote_code)
     if hasattr(tokenizer, "slice_start_id"):
-        return MiniCPMVImageInput(
+        return MiniCPMVRawImageInput(
             image=image,
             im_start_id=torch.tensor(tokenizer.im_start_id),
             im_end_id=torch.tensor(tokenizer.im_end_id),
             slice_start_id=torch.tensor(tokenizer.slice_start_id),
             slice_end_id=torch.tensor(tokenizer.slice_end_id))
     else:
-        return MiniCPMVImageInput(image=image,
-                                  im_start_id=torch.tensor(
-                                      tokenizer.im_start_id),
-                                  im_end_id=torch.tensor(tokenizer.im_end_id))
+        return MiniCPMVRawImageInput(
+            image=image,
+            im_start_id=torch.tensor(tokenizer.im_start_id),
+            im_end_id=torch.tensor(tokenizer.im_end_id))
 
 
 def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
@@ -280,20 +304,25 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
 
     pattern = "(<image>./</image>)"
     images = multi_modal_data["image"]
-    if isinstance(images, Image.Image):
-        images = [images]
     image_tags = re.findall(pattern, prompt)
 
     if len(image_tags) == 0:
         new_token_ids = token_ids
         new_prompt = prompt
     else:
+        if isinstance(images, dict):
+            image_size_list = images.get("image_size_list")
+            images = [images.get("image_embeds")]
+        else:
+            if isinstance(images, Image.Image):
+                images = [images]
+            image_size_list = [image.size for image in images]
+
         text_chunks = prompt.split(pattern)
         new_prompt_chunks: List[str] = []
-        for i in range(len(images)):
+        for i in range(len(image_size_list)):
             new_prompt_chunks += [
                 text_chunks[i],
-                get_placeholder(images[i].size, i)
+                get_placeholder(image_size_list[i], i)
             ]
         new_prompt_chunks.append(text_chunks[-1])
         new_prompt = "".join(new_prompt_chunks)
@@ -323,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
     if not isinstance(data, list):
         raise ValueError(
             "Image input must be list of MiniCPMVImageInput, got (%s)", data)
-    batch_data = image_processor \
-        .preprocess([img["image"] for img in data], return_tensors="pt") \
-        .data
+
+    if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor):
+        batch_data = {
+            "image_embeds": data[0]['image'],
+        }
+    else:
+        batch_data = image_processor \
+            .preprocess([img["image"] for img in data], return_tensors="pt") \
+            .data
 
     if len(data) > 0:
         batch_data["im_start_id"] = data[0]["im_start_id"]
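For context, each entry reaching this mapper is a MiniCPMVRawImageInput-style dict built by the input processor, and the new branch only checks whether its "image" field is still a raw image or already an embedding tensor. A toy illustration with made-up token ids and tensor sizes:

import torch
from PIL import Image

# Raw-image path: "image" is a PIL image, so the HF image processor is invoked.
raw_item = {
    "image": Image.new("RGB", (448, 448)),
    "im_start_id": torch.tensor(101),   # made-up token ids
    "im_end_id": torch.tensor(102),
}

# Embedding path: "image" is already a tensor and is forwarded as "image_embeds".
embeds_item = {
    "image": torch.randn(1, 64, 4096),  # (num_images, image_feature_size, hidden_size)
    "im_start_id": torch.tensor(101),
    "im_end_id": torch.tensor(102),
}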
@@ -380,7 +415,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     def get_embedding(
         self,
         input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImagePixelInputs],
+        image_inputs: Optional[MiniCPMVImageInputs],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
         if hasattr(self.config, "scale_emb"):
@@ -389,7 +424,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         if image_inputs is None: # No image
             vision_hidden_states = torch.tensor([], device=input_ids.device)
         else:
-            vision_hidden_states = self.get_vision_hidden_states(image_inputs)
+            if image_inputs["type"] == "image_embeds":
+                vision_hidden_states = (image_inputs["data"].type(
+                    vlm_embedding.dtype).to(vlm_embedding.device))
+            else:
+                vision_hidden_states = self.get_vision_hidden_states(
+                    image_inputs)
 
         # See NOTE in _parse_and_validate_inputs
         image_bounds = image_inputs["image_bounds"]
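Whichever branch produces vision_hidden_states above, get_embedding then writes those features into the token embeddings at the positions recorded in image_bounds. A self-contained toy of that scatter step, with made-up sizes and values rather than the file's exact code:

import torch

hidden_size = 8
vlm_embedding = torch.zeros(10, hidden_size)           # embeddings for 10 prompt tokens
vision_hidden_states = torch.ones(1, 4, hidden_size)   # one image worth 4 image tokens
image_bounds = torch.tensor([[3, 7]])                  # (start, stop): tokens 3..6 are the image

# Expand each (start, stop) pair into the token indices it covers.
image_indices = torch.stack([
    torch.arange(start, stop, dtype=torch.long)
    for start, stop in image_bounds.tolist()
])
# Overwrite those rows of the token-embedding matrix with the vision features.
vlm_embedding.scatter_(
    0,
    image_indices.view(-1, 1).repeat(1, hidden_size),
    vision_hidden_states.view(-1, hidden_size))

print(vlm_embedding[3:7])  # the image span now holds the vision features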
@@ -440,9 +480,23 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         **kwargs: object,
-    ) -> Optional[MiniCPMVImagePixelInputs]:
+    ) -> Optional[MiniCPMVImageInputs]:
         pixel_values = kwargs.pop("pixel_values", [])
         tgt_sizes = kwargs.pop("tgt_sizes", [])
+        im_start_id = kwargs.pop("im_start_id", None)
+        im_end_id = kwargs.pop("im_end_id", None)
+        slice_start_id = kwargs.pop("slice_start_id", None)
+        slice_end_id = kwargs.pop("slice_end_id", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if image_embeds is not None:
+            return MiniCPMVImageEmbeddingInputs(
+                image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                     im_end_id, slice_start_id,
                                                     slice_end_id),
+                data=image_embeds,
+                type="image_embeds",
+            )
 
         if not isinstance(pixel_values, (torch.Tensor, list)):
             raise ValueError("Incorrect type of pixel values. "
@@ -477,10 +531,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         if len(pixel_values_flat) == 0:
             return None
 
-        im_start_id = kwargs.pop("im_start_id", None)
-        im_end_id = kwargs.pop("im_end_id", None)
-        slice_start_id = kwargs.pop("slice_start_id", None)
-        slice_end_id = kwargs.pop("slice_end_id", None)
         if im_start_id is None:
             return None
 
@@ -488,8 +538,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
             image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                 im_end_id, slice_start_id,
                                                 slice_end_id),
-            pixel_values=pixel_values_flat,
+            data=pixel_values_flat,
             tgt_sizes=torch.stack(tgt_sizes_flat),
+            type="pixel_values",
         )
 
     def forward(
@@ -610,8 +661,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
         raise NotImplementedError
 
     def is_default_weight_loading(self, name: str) -> bool:
@@ -705,9 +756,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
             res.append(self.resampler(vision_embedding, tgt_size))
         return torch.vstack(res)
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
 
         return self.get_vision_embedding(pixel_values)
 
@@ -793,9 +844,9 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         vision_embedding = self.resampler(vision_embedding, tgt_sizes)
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device
@@ -909,9 +960,9 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         )
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device