[Bugfix] Fix input processor for InternVL2 model (#7164)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent ab0f5e2823
commit b764547616
@@ -5,6 +5,7 @@ import pytest
 import torch
 from huggingface_hub import snapshot_download
 from PIL.Image import Image
+from transformers import AutoConfig

 from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
                                                  IMG_START,
@@ -26,10 +27,15 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({

 # we use snapshot_download to prevent conflicts between
 # dynamic_module and trust_remote_code for hf_runner
+DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
 models = [
-    snapshot_download("OpenGVLab/InternVL2-1B"),
-    snapshot_download("OpenGVLab/InternVL2-2B"),
-    # snapshot_download("OpenGVLab/InternVL2-4B"),  # broken
+    snapshot_download("OpenGVLab/InternVL2-1B",
+                      allow_patterns=DOWNLOAD_PATTERN),
+    snapshot_download("OpenGVLab/InternVL2-2B",
+                      allow_patterns=DOWNLOAD_PATTERN),
+    # Broken due to outdated implementation of Phi-3
+    # See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
+    # snapshot_download("OpenGVLab/InternVL2-4B"),
 ]

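Note on the hunk above: huggingface_hub's snapshot_download accepts an allow_patterns filter, so the test only mirrors the files it actually needs instead of the whole repository. A minimal sketch of the same call, using the repo id and patterns from the test:

from huggingface_hub import snapshot_download

# Restrict the snapshot to config, code, tokenizer and weight files.
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]

model_path = snapshot_download("OpenGVLab/InternVL2-1B",
                               allow_patterns=DOWNLOAD_PATTERN)
print(model_path)  # local directory containing only the matched files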
@@ -41,8 +47,17 @@ class InternVLProcessor:
         self.tokenizer = hf_runner.tokenizer
         self.dtype = hf_runner.model.dtype

+        self.config = AutoConfig.from_pretrained(hf_runner.model_name)
+        self.vision_config = self.config.vision_config
+        self.use_thumbnail = self.config.use_thumbnail
+        self.min_num = self.config.min_dynamic_patch
+        self.max_num = self.config.max_dynamic_patch
+        self.image_size = self.vision_config.image_size
+
     def __call__(self, text: str, images: Image, **kwargs):
-        pixel_values = image_to_pixel_values(images).to(self.dtype)
+        pixel_values = image_to_pixel_values(images, self.image_size,
+                                             self.min_num, self.max_num,
+                                             self.use_thumbnail).to(self.dtype)
         num_patches_list = [pixel_values.shape[0]]
         for num_patches in num_patches_list:
             context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
@@ -38,9 +38,6 @@ IMG_CONTEXT = '<IMG_CONTEXT>'
 IMAGENET_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_STD = (0.229, 0.224, 0.225)

-MAX_IMAGE_FEATURE_SIZE_WIDTH = 3000
-MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500
-

 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -84,11 +81,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
     return best_ratio


-def calculate_num_blocks(orig_width: int,
-                         orig_height: int,
-                         min_num=1,
-                         max_num=6,
-                         image_size=448):
+def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
+                         max_num: int,
+                         image_size: int) -> Tuple[int, int, int]:
     aspect_ratio = orig_width / orig_height

     # calculate the existing image aspect ratio
@@ -110,11 +105,9 @@ def calculate_num_blocks(orig_width: int,


 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def dynamic_preprocess(image,
-                       min_num=1,
-                       max_num=6,
-                       image_size=448,
-                       use_thumbnail=False):
+def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
+                       image_size: int,
+                       use_thumbnail: int) -> List[Image.Image]:
     orig_width, orig_height = image.size

     blocks, target_width, target_height = calculate_num_blocks(
@@ -138,12 +131,14 @@ def dynamic_preprocess(image,


 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def image_to_pixel_values(image: Image.Image, input_size=448, max_num=6):
+def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
+                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
     transform = build_transform(input_size=input_size)
     images = dynamic_preprocess(image,
+                                min_num=min_num,
+                                max_num=max_num,
                                 image_size=input_size,
-                                use_thumbnail=True,
-                                max_num=max_num)
+                                use_thumbnail=use_thumbnail)
     pixel_values = [transform(image) for image in images]
     pixel_values = torch.stack(pixel_values)
     return pixel_values
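The refactor above drops the hard-coded defaults (input_size=448, max_num=6, use_thumbnail=True) and makes every tiling parameter explicit. A hedged sketch of how a caller is expected to thread the HF config through, mirroring the updated input mapper (the config attribute names are the ones used in this commit; the image path is hypothetical):

from PIL import Image
from transformers import AutoConfig

from vllm.model_executor.models.internvl import image_to_pixel_values

config = AutoConfig.from_pretrained("OpenGVLab/InternVL2-2B",
                                    trust_remote_code=True)

pixel_values = image_to_pixel_values(
    Image.open("example.jpg"),  # hypothetical input image
    input_size=config.vision_config.image_size,
    min_num=config.min_dynamic_patch,
    max_num=config.max_dynamic_patch,
    use_thumbnail=config.use_thumbnail,
)
# First dimension is the number of tiles (plus one if a thumbnail is appended).
print(pixel_values.shape)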
@@ -159,12 +154,18 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
 def get_max_internvl_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PretrainedConfig)
     vision_config = hf_config.vision_config

+    use_thumbnail = hf_config.use_thumbnail
+    max_dynamic_patch = hf_config.max_dynamic_patch
+    if use_thumbnail:
+        max_dynamic_patch += 1
+    downsample_ratio = hf_config.downsample_ratio
+
     image_size = vision_config.image_size
     patch_size = vision_config.patch_size
-    downsample_ratio = hf_config.downsample_ratio
     num_patches = get_internvl_num_patches(image_size, patch_size,
                                            downsample_ratio)
-    return num_patches * 7
+    return num_patches * max_dynamic_patch


 def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
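The hard-coded factor of 7 above corresponded to the old defaults (max_num=6 dynamic tiles plus a thumbnail); the new code derives the same bound from the model config instead. A small illustrative calculation with hypothetical values:

# Hypothetical values for illustration only.
num_patches = 256        # tokens per tile, as returned by get_internvl_num_patches
max_dynamic_patch = 6    # hf_config.max_dynamic_patch
use_thumbnail = True     # hf_config.use_thumbnail

if use_thumbnail:
    max_dynamic_patch += 1

max_image_tokens = num_patches * max_dynamic_patch
print(max_image_tokens)  # 1792, matching the old num_patches * 7 for these defaults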
@@ -176,21 +177,27 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     hf_config = ctx.get_hf_config(PretrainedConfig)
     vision_config = hf_config.vision_config

-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        width, height = image_data.size
-        num_blocks, _, _ = calculate_num_blocks(width, height)
-    elif isinstance(image_data, torch.Tensor):
-        raise NotImplementedError("Embeddings input is not supported yet")
-    else:
-        raise TypeError(f"Invalid image type: {type(image_data)}")
-
     image_size = vision_config.image_size
     patch_size = vision_config.patch_size
     downsample_ratio = hf_config.downsample_ratio
     num_patches = get_internvl_num_patches(image_size, patch_size,
                                            downsample_ratio)

+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        width, height = image_data.size
+        min_num = hf_config.min_dynamic_patch
+        max_num = hf_config.max_dynamic_patch
+        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
+                                                max_num, image_size)
+        # add thumbnail image if num_blocks > 1
+        if hf_config.use_thumbnail and num_blocks > 1:
+            num_blocks += 1
+    elif isinstance(image_data, torch.Tensor):
+        raise NotImplementedError("Embeddings input is not supported yet")
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
+
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)

@@ -198,8 +205,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     if prompt is None:
         prompt = tokenizer.decode(prompt_token_ids)
-    image_prompt = IMG_START + IMG_CONTEXT * (num_blocks +
-                                              1) * num_patches + IMG_END
+    image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
     new_prompt = prompt.replace('<image>', image_prompt, 1)
     new_prompt_token_ids = tokenizer.encode(new_prompt)

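This hunk is the core of the fix: calculate_num_blocks now receives the config values, and num_blocks is already incremented when a thumbnail tile is added (previous hunk), so the prompt expansion no longer tacks on an unconditional extra tile. A sketch of the expansion with made-up small numbers (the IMG_START/IMG_END strings below are placeholders; only IMG_CONTEXT's value appears in this diff):

IMG_START, IMG_END = '<img>', '</img>'  # assumed placeholder strings
IMG_CONTEXT = '<IMG_CONTEXT>'

num_blocks = 3   # tiles for this image, thumbnail already counted if used
num_patches = 4  # context tokens per tile (illustrative; the real value is larger)

image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
prompt = "USER: <image>\nWhat is in this picture? ASSISTANT:"
new_prompt = prompt.replace('<image>', image_prompt, 1)
print(new_prompt.count(IMG_CONTEXT))  # 12 == num_blocks * num_patches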
@@ -209,8 +215,19 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):


 def input_mapper_for_internvl(ctx: InputContext, data: object):
+    hf_config = ctx.get_hf_config(PretrainedConfig)
+
+    use_thumbnail = hf_config.use_thumbnail
+    min_num = hf_config.min_dynamic_patch
+    max_num = hf_config.max_dynamic_patch
+    image_size = hf_config.vision_config.image_size
+
     if isinstance(data, Image.Image):
-        data = image_to_pixel_values(data)
+        data = image_to_pixel_values(data,
+                                     image_size,
+                                     min_num,
+                                     max_num,
+                                     use_thumbnail=use_thumbnail)
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
@@ -240,10 +257,17 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
                                         add_special_tokens=False)[0],
         image_feature_size_override=image_feature_size,
     )

+    image_size = vision_config.image_size
+    min_num = hf_config.min_dynamic_patch
+    max_num = hf_config.max_dynamic_patch
+    max_image_width = max_num * image_size
+    max_image_height = min_num * image_size
+
     mm_data = dummy_image_for_clip(
         vision_config,
-        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
-        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+        image_width_override=max_image_width,
+        image_height_override=max_image_height,
     )

     return seq_data, mm_data
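The removed MAX_IMAGE_FEATURE_SIZE constants (3000 x 500) are replaced with dimensions derived from the config, so the dummy profiling image is tied to the tiling the config actually allows (max_num tiles wide by min_num tiles tall) instead of fixed pixel sizes. A small illustrative calculation with hypothetical config values:

# Hypothetical config values for illustration.
image_size = 448  # vision_config.image_size
min_num = 1       # hf_config.min_dynamic_patch
max_num = 12      # hf_config.max_dynamic_patch

max_image_width = max_num * image_size   # 5376
max_image_height = min_num * image_size  # 448
print(max_image_width, max_image_height)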