[BugFix] Propagate 'trust_remote_code' setting in internvl and minicpmv (#8250)
parent fc3afc20df
commit e3dd0692fa
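The commit replaces the hard-coded trust_remote_code=True in the InternVL, MiniCPM-V and Qwen multimodal input processors/mappers with the value carried on the model config, so the flag the user actually passed is the one honored when the tokenizer is loaded. A minimal sketch of the pattern applied throughout the diff below (get_ctx_tokenizer is an illustrative helper name, not part of the commit):

import torch  # noqa: F401  (kept for parity with the model files)
from vllm.multimodal.utils import cached_get_tokenizer


def get_ctx_tokenizer(ctx):
    # ctx is a vllm InputContext; model_config.trust_remote_code reflects
    # what the user passed at engine construction time.
    model_config = ctx.model_config
    # Before this commit: trust_remote_code=True was hard-coded here.
    return cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

With this in place, remote tokenizer code is only trusted when the engine was actually created with trust_remote_code=True (for example LLM(model=..., trust_remote_code=True)), rather than unconditionally for these models.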
InternVL:

@@ -230,8 +230,9 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)

     prompt = llm_inputs.get("prompt")
     prompt_token_ids = llm_inputs["prompt_token_ids"]
@@ -278,8 +279,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
             use_thumbnail=use_thumbnail) for img in data
     ]
     model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
     image_token_id = tokenizer.encode(IMG_CONTEXT,
                                       add_special_tokens=False,
                                       return_tensors="pt")[0]
@@ -298,8 +300,9 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
     model_config = ctx.model_config
     hf_config = ctx.get_hf_config()
     vision_config = hf_config.vision_config
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)

     seq_data = dummy_seq_data_for_clip(
         vision_config,
MiniCPM-V:

@@ -33,6 +33,7 @@ from PIL import Image
 from torch import nn
 from torch.nn.init import trunc_normal_
 from transformers import PretrainedConfig
+from typing_extensions import NotRequired

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -52,6 +53,7 @@ from vllm.model_executor.models.minicpm import MiniCPMModel
 from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
@@ -64,6 +66,17 @@ _KEYS_TO_MODIFY_MAPPING = {
 }


+class MiniCPMVImageInput(TypedDict):
+    """Input mapper input with auxiliary data for computing image bounds."""
+    image: Image.Image
+
+    # Image bounds token ids in 0-dim scaler tensor.
+    im_start_id: torch.Tensor
+    im_end_id: torch.Tensor
+    slice_start_id: NotRequired[torch.Tensor]
+    slice_end_id: NotRequired[torch.Tensor]
+
+
 class MiniCPMVImagePixelInputs(TypedDict):
     pixel_values: List[torch.Tensor]
     """
@@ -88,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
     """


-MiniCPMVImageInputs = MiniCPMVImagePixelInputs
-
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


@@ -234,6 +245,25 @@ class Resampler2_5(BaseResampler):
         return x


+def _build_image_input(ctx: InputContext,
+                       image: Image.Image) -> MiniCPMVImageInput:
+    tokenizer = cached_get_tokenizer(
+        ctx.model_config.tokenizer,
+        trust_remote_code=ctx.model_config.trust_remote_code)
+    if hasattr(tokenizer, "slice_start_id"):
+        return MiniCPMVImageInput(
+            image=image,
+            im_start_id=torch.tensor(tokenizer.im_start_id),
+            im_end_id=torch.tensor(tokenizer.im_end_id),
+            slice_start_id=torch.tensor(tokenizer.slice_start_id),
+            slice_end_id=torch.tensor(tokenizer.slice_end_id))
+    else:
+        return MiniCPMVImageInput(image=image,
+                                  im_start_id=torch.tensor(
+                                      tokenizer.im_start_id),
+                                  im_end_id=torch.tensor(tokenizer.im_end_id))
+
+
 def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
     version_float = getattr(config, "version", None)

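For orientation, the MiniCPMVImageInput produced by _build_image_input is just a TypedDict bundling the PIL image with the special-token ids used later for image-bound lookup. A tiny sketch with made-up id values (in the real path they come from the trust_remote_code tokenizer loaded above):

import torch
from PIL import Image

# 101 / 102 are made-up stand-ins for tokenizer.im_start_id / im_end_id.
example_image_input = {
    "image": Image.new("RGB", (448, 448), color=0),
    "im_start_id": torch.tensor(101),
    "im_end_id": torch.tensor(102),
}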
@@ -257,10 +287,13 @@ def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
     return SequenceData.from_token_counts((0, seq_len))


-def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
+def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
+                             num_images: int):
     width = height = hf_config.image_size
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
+    image = _build_image_input(ctx,
+                               image=Image.new("RGB", (width, height),
+                                               color=0))
+    return {"image": [image] if num_images == 1 else [image] * num_images}


 def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
@@ -269,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
     num_images = mm_counts["image"]

     seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
-    mm_data = dummy_image_for_minicpmv(hf_config, num_images)
+    mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)

     return seq_data, mm_data

@@ -280,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs
     model_config = ctx.model_config
     version = get_version_by_config(model_config.hf_config)
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
     image_processor = cached_get_image_processor(model_config.tokenizer)

     def get_placeholder(image_size: Tuple[int, int], num_image: int):
@@ -317,6 +351,10 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
     new_prompt = "".join(new_prompt_chunks)
     new_token_ids = tokenizer.encode(new_prompt)

+    multi_modal_data["image"] = [
+        _build_image_input(ctx, image) for image in images
+    ]
+
     llm_inputs = LLMInputs(
         prompt_token_ids=new_token_ids,
         prompt=new_prompt,
@@ -325,6 +363,32 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
     return llm_inputs


+def input_mapper_for_minicpmv(ctx: InputContext, data: object):
+    model_config = ctx.model_config
+
+    image_processor = cached_get_image_processor(
+        model_config.model, trust_remote_code=model_config.trust_remote_code)
+    if image_processor is None:
+        raise RuntimeError("No HuggingFace processor is available "
+                           "to process the image object")
+
+    if not isinstance(data, list):
+        raise ValueError(
+            "Image input must be list of MiniCPMVImageInput, got (%s)", data)
+    batch_data = image_processor \
+        .preprocess([img["image"] for img in data], return_tensors="pt") \
+        .data
+
+    if len(data) > 0:
+        batch_data["im_start_id"] = data[0]["im_start_id"]
+        batch_data["im_end_id"] = data[0]["im_end_id"]
+        if "slice_start_id" in data[0]:
+            batch_data["slice_start_id"] = data[0]["slice_start_id"]
+            batch_data["slice_end_id"] = data[0]["slice_end_id"]
+
+    return MultiModalInputs(batch_data)
+
+
 class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
@@ -365,7 +429,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
     def get_embedding(
         self,
         input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImageInputs],
+        image_inputs: Optional[MiniCPMVImagePixelInputs],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
         if hasattr(self.config, "scale_emb"):
@@ -393,14 +457,20 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):

         return vlm_embedding, vision_hidden_states

-    def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
-        tokenizer = cached_get_tokenizer(self.config._name_or_path,
-                                         trust_remote_code=True)
-        start_cond = input_ids == tokenizer.im_start_id
-        end_cond = input_ids == tokenizer.im_end_id
-        if hasattr(tokenizer, "slice_start_id"):
-            start_cond |= (input_ids == tokenizer.slice_start_id)
-            end_cond |= (input_ids == tokenizer.slice_end_id)
+    def _get_image_bounds(
+            self,
+            input_ids: torch.Tensor,
+            im_start_id: torch.Tensor,
+            im_end_id: torch.Tensor,
+            slice_start_id: Optional[torch.Tensor] = None,
+            slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # All the images in the batch should share the same special image
+        # bound token ids.
+        start_cond = input_ids == im_start_id[0]
+        end_cond = input_ids == im_end_id[0]
+        if slice_start_id is not None:
+            start_cond |= (input_ids == slice_start_id[0])
+            end_cond |= (input_ids == slice_end_id[0])

         image_start_tokens, = torch.where(start_cond)
         image_start_tokens += 1
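As a self-contained illustration of the lookup the reworked _get_image_bounds performs (made-up token ids and 0-dim tensors; the real method receives batched id tensors, hence the [0] indexing above):

import torch

# 101 / 102 stand in for the tokenizer's im_start_id / im_end_id.
input_ids = torch.tensor([5, 101, 7, 7, 102, 9, 101, 7, 102])
im_start_id = torch.tensor(101)
im_end_id = torch.tensor(102)

start_cond = input_ids == im_start_id
end_cond = input_ids == im_end_id

image_start_tokens, = torch.where(start_cond)
image_start_tokens += 1  # a bound begins right after the start marker
image_end_tokens, = torch.where(end_cond)

bounds = torch.hstack(
    [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
print(bounds)  # tensor([[2, 4], [7, 8]]) -> two image spans in this sequence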
@@ -419,7 +489,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
         self,
         input_ids: torch.Tensor,
         **kwargs: object,
-    ) -> Optional[MiniCPMVImageInputs]:
+    ) -> Optional[MiniCPMVImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", [])
         tgt_sizes = kwargs.pop("tgt_sizes", [])

@@ -456,8 +526,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
         if len(pixel_values_flat) == 0:
             return None

-        return MiniCPMVImageInputs(
-            image_bounds=self._get_image_bounds(input_ids),
+        im_start_id = kwargs.pop("im_start_id", None)
+        im_end_id = kwargs.pop("im_end_id", None)
+        slice_start_id = kwargs.pop("slice_start_id", None)
+        slice_end_id = kwargs.pop("slice_end_id", None)
+        if im_start_id is None:
+            return None
+
+        return MiniCPMVImagePixelInputs(
+            image_bounds=self._get_image_bounds(input_ids, im_start_id,
+                                                im_end_id, slice_start_id,
+                                                slice_end_id),
             pixel_values=pixel_values_flat,
             tgt_sizes=torch.stack(tgt_sizes_flat),
         )
@@ -564,8 +643,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
     ) -> torch.Tensor:
         raise NotImplementedError

-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         raise NotImplementedError

     def is_default_weight_loading(self, name: str) -> bool:
@@ -654,8 +733,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
             res.append(self.resampler(vision_embedding, tgt_size))
         return torch.vstack(res)

-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]

         return self.get_vision_embedding(pixel_values)
@@ -713,8 +792,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
         vision_embedding = self.resampler(vision_embedding, tgt_sizes)
         return vision_embedding

-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]
         tgt_sizes = data["tgt_sizes"]

@@ -807,8 +886,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         ).last_hidden_state
         return vision_embedding

-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]
         tgt_sizes = data["tgt_sizes"]

@@ -851,7 +930,7 @@ _SUPPORT_VERSION = {
 }


-@MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
Qwen:

@@ -674,8 +674,9 @@ def input_processor_for_qwen(ctx: InputContext,
     prompt = llm_inputs.get("prompt")
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
     image_data = multi_modal_data["image"]
     if isinstance(image_data, torch.Tensor):
         num_dims = len(image_data.shape)
@@ -735,8 +736,9 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
         return MultiModalInputs()

     model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)

     image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
                                       add_special_tokens=False,
@@ -824,8 +826,9 @@ def dummy_data_for_qwen(
     # We have a visual component - use images to warm up
     num_images = mm_counts["image"]
     model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)

     # Build the image prompts with no imgpads; the tokenizer will add img pads
     image_prompt = ''.join(
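A usage sketch of what this enables: the flag passed when constructing the engine is now the one honored inside these multimodal input paths (the model name below is illustrative):

from vllm import LLM

# trust_remote_code is stored on model_config.trust_remote_code and, with
# this fix, reused when the input processors/mappers load the tokenizer.
llm = LLM(model="openbmb/MiniCPM-V-2_6", trust_remote_code=True)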