[Model] Expose dynamic_image_size as mm_processor_kwargs for InternVL2 models (#10518)
Signed-off-by: Isotr0py <2037008807@qq.com>
parent 8a93a598d9
commit d5ec121f95
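The newly exposed flag is consumed at engine construction time, alongside the existing max_dynamic_patch override. A minimal usage sketch, assuming vLLM's offline LLM entry point; the kwarg values are illustrative, not taken from this diff:

# Usage sketch: overriding InternVL2 image preprocessing via mm_processor_kwargs.
from vllm import LLM

llm = LLM(
    model="OpenGVLab/InternVL2-2B",
    trust_remote_code=True,
    mm_processor_kwargs={
        "max_dynamic_patch": 4,
        # Newly exposed by this commit; False pins preprocessing to one patch.
        "dynamic_image_size": False,
    },
)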
@@ -0,0 +1,206 @@
+"""Tests for InternVL's multimodal preprocessing kwargs."""
+from typing import Callable, Optional
+
+import pytest
+from transformers import AutoTokenizer
+
+from vllm.inputs import InputContext, token_inputs
+from vllm.multimodal import MultiModalRegistry
+
+from .....conftest import _ImageAssets
+from ....utils import build_model_context
+
+models = ["OpenGVLab/InternVL2-2B"]
+
+
+# Wrap lazy imports to avoid initializing CUDA during test collection
+@pytest.fixture()
+def input_processor_for_internvl():
+    from vllm.model_executor.models.internvl import InternVLInputPipeline
+
+    pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
+    return pipeline.input_processor
+
+
+@pytest.fixture()
+def dummy_data_for_internvl():
+    from vllm.model_executor.models.internvl import InternVLInputPipeline
+
+    pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
+    return pipeline.dummy_data
+
+
+@pytest.fixture()
+def get_max_internvl_image_tokens():
+    from vllm.model_executor.models.internvl import (
+        get_max_internvl_image_tokens)
+    return get_max_internvl_image_tokens
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
+@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
+def test_input_mapper_override(
+    model: str,
+    image_assets: _ImageAssets,
+    max_dynamic_patch: int,
+    dynamic_image_size: Optional[bool],
+):
+    mm_processor_kwargs = {
+        "max_dynamic_patch": max_dynamic_patch,
+    }
+    if dynamic_image_size is not None:
+        mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size
+
+    expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
+    if dynamic_image_size is False:
+        expected_num_patches = 1
+
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=mm_processor_kwargs,
+    )
+
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+
+    image = image_assets[0].pil_image.resize((448 * 2, 448 * 2))
+    vllm_result = mm_registry.map_input(
+        ctx.model_config,
+        {"image": image},
+    )
+    assert vllm_result["pixel_values"].size(1) == expected_num_patches
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
+@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
+def test_max_tokens_override(
+    get_max_internvl_image_tokens: Callable,
+    model: str,
+    max_dynamic_patch: Optional[int],
+    dynamic_image_size: Optional[bool],
+):
+    """Ensure get_max_internvl_image_tokens handles mm_processor_kwargs."""
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+
+    if max_dynamic_patch is None:
+        max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
+    expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
+    if dynamic_image_size is False:
+        expected_num_patches = 1
+    expected_max_tokens = 256 * expected_num_patches
+
+    actual_max_tokens = get_max_internvl_image_tokens(
+        ctx=InputContext(ctx.model_config),
+        max_dynamic_patch=max_dynamic_patch,
+        dynamic_image_size=dynamic_image_size,
+    )
+    assert expected_max_tokens == actual_max_tokens
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_imgs", [1, 2])
+@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
+@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
+def test_dummy_data_override(
+    dummy_data_for_internvl: Callable,
+    model: str,
+    num_imgs: int,
+    max_dynamic_patch: Optional[int],
+    dynamic_image_size: Optional[bool],
+):
+    """Ensure dummy_data_for_internvl handles kwargs properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the dummy data func.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+
+    if max_dynamic_patch is None:
+        max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
+    expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
+    if dynamic_image_size is False:
+        expected_num_patches = 1
+    expected_max_tokens = 256 * expected_num_patches
+
+    dummy_data = dummy_data_for_internvl(
+        ctx=ctx,
+        seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
+        mm_counts={"image": num_imgs},
+        max_dynamic_patch=max_dynamic_patch,
+        dynamic_image_size=dynamic_image_size,
+    )
+    sequence_data = dummy_data.seq_data
+
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+    image_token_id = tokenizer.encode('<IMG_CONTEXT>',
+                                      add_special_tokens=False)[0]
+
+    # Ensure we have the right number of placeholders per size
+    img_tok_count = sequence_data.get_token_ids().count(image_token_id)
+    assert img_tok_count == expected_max_tokens * num_imgs
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
+@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
+@pytest.mark.parametrize("num_imgs", [1, 2])
+def test_input_processor_override(
+    input_processor_for_internvl: Callable,
+    image_assets: _ImageAssets,
+    model: str,
+    num_imgs: int,
+    max_dynamic_patch: int,
+    dynamic_image_size: Optional[bool],
+):
+    """Ensure input_processor_for_internvl handles kwargs properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the custom input processor.
+    expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
+    if dynamic_image_size is False:
+        expected_num_patches = 1
+
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+    expected_toks_per_img = 256 * expected_num_patches
+
+    # Build the image str / prompt based on the number of images we pass
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+    placeholders = "<image>" if num_imgs == 1 else "\n".join(
+        f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
+    prompt = placeholders
+    images = [image_assets[0].pil_image.resize((448 * 2, 448 * 2))] * num_imgs
+
+    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
+                          prompt=prompt,
+                          multi_modal_data={"image": images})
+
+    processed_inputs = input_processor_for_internvl(
+        ctx,
+        inputs,
+        max_dynamic_patch=max_dynamic_patch,
+        dynamic_image_size=dynamic_image_size,
+    )
+
+    # Ensure we have the right number of placeholders per expected patch count
+    image_token_id = tokenizer.encode('<IMG_CONTEXT>',
+                                      add_special_tokens=False)[0]
+    img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
+    assert img_tok_count == expected_toks_per_img * num_imgs
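All four tests above derive the expected patch count the same way: a thumbnail patch is appended whenever more than one dynamic patch is produced, and disabling dynamic_image_size forces a single patch. A hypothetical helper (name not from the source) capturing that arithmetic; the 256 tokens-per-patch constant matches this checkpoint's vision config:

# Hypothetical helper mirroring the tests' expectations above.
def expected_internvl_patches(max_dynamic_patch: int,
                              dynamic_image_size: bool = True) -> int:
    if not dynamic_image_size:
        return 1  # dynamic sizing off: single patch, no thumbnail
    return max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1


# e.g. max_dynamic_patch=4 -> 5 patches -> 5 * 256 = 1280 image tokens
assert expected_internvl_patches(4) == 5
assert expected_internvl_patches(1) == 1
assert expected_internvl_patches(4, dynamic_image_size=False) == 1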
@@ -123,8 +123,15 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
     return blocks, target_width, target_height
 
 
-def calculate_num_blocks_wrapper(hf_config: PretrainedConfig,
-                                 max_dynamic_patch: Optional[int] = None):
+def calculate_num_blocks_wrapper(
+    hf_config: PretrainedConfig,
+    max_dynamic_patch: Optional[int] = None,
+    dynamic_image_size: Optional[bool] = None,
+):
+    if dynamic_image_size is None:
+        dynamic_image_size = hf_config.dynamic_image_size
+
+    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
     if max_dynamic_patch is None:
         max_dynamic_patch = hf_config.max_dynamic_patch
     min_num = hf_config.min_dynamic_patch
@@ -183,10 +190,17 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
     return pixel_values
 
 
-def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
-                                  max_dynamic_patch: Optional[int] = None):
+def image_to_pixel_values_wrapper(
+    hf_config: PretrainedConfig,
+    max_dynamic_patch: Optional[int] = None,
+    dynamic_image_size: Optional[bool] = None,
+):
     image_size = hf_config.vision_config.image_size
     min_num = hf_config.min_dynamic_patch
+    if dynamic_image_size is None:
+        dynamic_image_size = hf_config.dynamic_image_size
+
+    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
     if max_dynamic_patch is None:
         max_dynamic_patch = hf_config.max_dynamic_patch
     use_thumbnail = hf_config.use_thumbnail
@@ -207,11 +221,17 @@ def get_internvl_num_patches(hf_config: PretrainedConfig):
         (downsample_ratio**2))
 
 
-def get_max_internvl_image_tokens(ctx: InputContext,
-                                  *,
-                                  max_dynamic_patch: Optional[int] = None):
+def get_max_internvl_image_tokens(
+    ctx: InputContext,
+    *,
+    max_dynamic_patch: Optional[int] = None,
+    dynamic_image_size: Optional[bool] = None,
+):
     hf_config = ctx.get_hf_config()
+    if dynamic_image_size is None:
+        dynamic_image_size = hf_config.dynamic_image_size
 
+    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
     if max_dynamic_patch is None:
         max_dynamic_patch = hf_config.max_dynamic_patch
     use_thumbnail = hf_config.use_thumbnail
@@ -222,12 +242,18 @@ def get_max_internvl_image_tokens(ctx: InputContext,
     return num_patches * max_dynamic_patch
 
 
-def get_max_internvl_image_size(ctx: InputContext,
-                                *,
-                                max_dynamic_patch: Optional[int] = None):
+def get_max_internvl_image_size(
+    ctx: InputContext,
+    *,
+    max_dynamic_patch: Optional[int] = None,
+    dynamic_image_size: Optional[bool] = None,
+):
     hf_config = ctx.get_hf_config()
     image_size = hf_config.vision_config.image_size
+    if dynamic_image_size is None:
+        dynamic_image_size = hf_config.dynamic_image_size
 
+    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
     if max_dynamic_patch is None:
         max_dynamic_patch = hf_config.max_dynamic_patch
     use_thumbnail = hf_config.use_thumbnail
@@ -281,6 +307,7 @@ class InternVLInputPipeline:
         inputs: DecoderOnlyInputs,
         *,
         max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
     ) -> DecoderOnlyInputs:
         multi_modal_data = inputs.get("multi_modal_data")
         if multi_modal_data is None or "image" not in multi_modal_data:
@@ -292,7 +319,7 @@ class InternVLInputPipeline:
         image_data = multi_modal_data["image"]
         num_patches = get_internvl_num_patches(hf_config)
         num_blocks_calculator = calculate_num_blocks_wrapper(
-            hf_config, max_dynamic_patch)
+            hf_config, max_dynamic_patch, dynamic_image_size)
         if isinstance(image_data, Image.Image):
             width, height = image_data.size
             num_blocks, _, _ = num_blocks_calculator(width, height)
@@ -332,11 +359,12 @@ class InternVLInputPipeline:
         data: object,
         *,
         max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
    ):
         hf_config = ctx.get_hf_config()
 
         image_pixel_values_mapper = image_to_pixel_values_wrapper(
-            hf_config, max_dynamic_patch)
+            hf_config, max_dynamic_patch, dynamic_image_size)
         if isinstance(data, Image.Image):
             data = image_pixel_values_mapper(data)
             # Add an N dimension for number of images per prompt (currently 1).
@@ -366,13 +394,17 @@ class InternVLInputPipeline:
         mm_counts: Mapping[str, int],
         *,
         max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
     ):
         num_images = mm_counts["image"]
 
         hf_config = ctx.get_hf_config()
 
         image_feature_size = get_max_internvl_image_tokens(
-            ctx, max_dynamic_patch=max_dynamic_patch)
+            ctx,
+            max_dynamic_patch=max_dynamic_patch,
+            dynamic_image_size=dynamic_image_size,
+        )
         model_config = ctx.model_config
         tokenizer = cached_get_tokenizer(
             model_config.tokenizer,
@@ -388,7 +420,10 @@ class InternVLInputPipeline:
         )
 
         max_image_width, max_image_height = get_max_internvl_image_size(
-            ctx, max_dynamic_patch=max_dynamic_patch)
+            ctx,
+            max_dynamic_patch=max_dynamic_patch,
+            dynamic_image_size=dynamic_image_size,
+        )
 
         mm_data = dummy_image_for_clip(
             hf_config.vision_config,
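Every wrapper touched above repeats the same precedence rule. Distilled into a standalone sketch (function and parameter names are hypothetical; the hf_* arguments stand in for HF config fields): an explicit dynamic_image_size overrides the config; with dynamic sizing off, the patch budget collapses to 1; otherwise an explicit max_dynamic_patch overrides the config default.

from typing import Optional


# Sketch of the override resolution repeated across the wrappers above.
def resolve_max_dynamic_patch(
    hf_max_dynamic_patch: int,
    hf_dynamic_image_size: bool,
    max_dynamic_patch: Optional[int] = None,
    dynamic_image_size: Optional[bool] = None,
) -> int:
    if dynamic_image_size is None:
        dynamic_image_size = hf_dynamic_image_size
    if not dynamic_image_size:
        return 1  # dynamic sizing disabled: always a single patch
    if max_dynamic_patch is None:
        return hf_max_dynamic_patch
    return max_dynamic_patch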