diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py
index 66cb8dda..6aa01896 100644
--- a/tests/models/test_internvl.py
+++ b/tests/models/test_internvl.py
@@ -5,6 +5,7 @@ import pytest
 import torch
 from huggingface_hub import snapshot_download
 from PIL.Image import Image
+from transformers import AutoConfig
 
 from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
                                                  IMG_START,
@@ -26,10 +27,15 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
 
 # we use snapshot_download to prevent conflicts between
 # dynamic_module and trust_remote_code for hf_runner
+DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
 models = [
-    snapshot_download("OpenGVLab/InternVL2-1B"),
-    snapshot_download("OpenGVLab/InternVL2-2B"),
-    # snapshot_download("OpenGVLab/InternVL2-4B"),  # broken
+    snapshot_download("OpenGVLab/InternVL2-1B",
+                      allow_patterns=DOWNLOAD_PATTERN),
+    snapshot_download("OpenGVLab/InternVL2-2B",
+                      allow_patterns=DOWNLOAD_PATTERN),
+    # Broken due to outdated implementation of Phi-3
+    # See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
+    # snapshot_download("OpenGVLab/InternVL2-4B"),
 ]
 
 
@@ -41,8 +47,17 @@ class InternVLProcessor:
         self.tokenizer = hf_runner.tokenizer
         self.dtype = hf_runner.model.dtype
 
+        self.config = AutoConfig.from_pretrained(hf_runner.model_name)
+        self.vision_config = self.config.vision_config
+        self.use_thumbnail = self.config.use_thumbnail
+        self.min_num = self.config.min_dynamic_patch
+        self.max_num = self.config.max_dynamic_patch
+        self.image_size = self.vision_config.image_size
+
     def __call__(self, text: str, images: Image, **kwargs):
-        pixel_values = image_to_pixel_values(images).to(self.dtype)
+        pixel_values = image_to_pixel_values(images, self.image_size,
+                                             self.min_num, self.max_num,
+                                             self.use_thumbnail).to(self.dtype)
         num_patches_list = [pixel_values.shape[0]]
         for num_patches in num_patches_list:
             context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 8850fd7c..49f9a4c8 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -38,9 +38,6 @@ IMG_CONTEXT = '<IMG_CONTEXT>'
 IMAGENET_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_STD = (0.229, 0.224, 0.225)
 
-MAX_IMAGE_FEATURE_SIZE_WIDTH = 3000
-MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500
-
 
 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -84,11 +81,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
     return best_ratio
 
 
-def calculate_num_blocks(orig_width: int,
-                         orig_height: int,
-                         min_num=1,
-                         max_num=6,
-                         image_size=448):
+def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
+                         max_num: int,
+                         image_size: int) -> Tuple[int, int, int]:
     aspect_ratio = orig_width / orig_height
 
     # calculate the existing image aspect ratio
@@ -110,11 +105,9 @@
 
 
 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def dynamic_preprocess(image,
-                       min_num=1,
-                       max_num=6,
-                       image_size=448,
-                       use_thumbnail=False):
+def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
+                       image_size: int,
+                       use_thumbnail: bool) -> List[Image.Image]:
     orig_width, orig_height = image.size
 
     blocks, target_width, target_height = calculate_num_blocks(
@@ -138,12 +131,14 @@
 
 
 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def image_to_pixel_values(image: Image.Image, input_size=448, max_num=6):
+def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
+                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
     transform = build_transform(input_size=input_size)
     images = dynamic_preprocess(image,
+                                min_num=min_num,
+                                max_num=max_num,
                                 image_size=input_size,
-                                use_thumbnail=True,
-                                max_num=max_num)
+                                use_thumbnail=use_thumbnail)
     pixel_values = [transform(image) for image in images]
     pixel_values = torch.stack(pixel_values)
     return pixel_values
@@ -159,12 +154,18 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
 def get_max_internvl_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PretrainedConfig)
     vision_config = hf_config.vision_config
+
+    use_thumbnail = hf_config.use_thumbnail
+    max_dynamic_patch = hf_config.max_dynamic_patch
+    if use_thumbnail:
+        max_dynamic_patch += 1
+    downsample_ratio = hf_config.downsample_ratio
+
     image_size = vision_config.image_size
     patch_size = vision_config.patch_size
-    downsample_ratio = hf_config.downsample_ratio
     num_patches = get_internvl_num_patches(image_size, patch_size,
                                            downsample_ratio)
-    return num_patches * 7
+    return num_patches * max_dynamic_patch
 
 
 def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
@@ -176,21 +177,27 @@
     hf_config = ctx.get_hf_config(PretrainedConfig)
     vision_config = hf_config.vision_config
 
-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        width, height = image_data.size
-        num_blocks, _, _ = calculate_num_blocks(width, height)
-    elif isinstance(image_data, torch.Tensor):
-        raise NotImplementedError("Embeddings input is not supported yet")
-    else:
-        raise TypeError(f"Invalid image type: {type(image_data)}")
-
     image_size = vision_config.image_size
     patch_size = vision_config.patch_size
     downsample_ratio = hf_config.downsample_ratio
     num_patches = get_internvl_num_patches(image_size, patch_size,
                                            downsample_ratio)
 
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        width, height = image_data.size
+        min_num = hf_config.min_dynamic_patch
+        max_num = hf_config.max_dynamic_patch
+        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
+                                                max_num, image_size)
+        # add thumbnail image if num_blocks > 1
+        if hf_config.use_thumbnail and num_blocks > 1:
+            num_blocks += 1
+    elif isinstance(image_data, torch.Tensor):
+        raise NotImplementedError("Embeddings input is not supported yet")
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
+
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
 
@@ -198,8 +205,7 @@
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     if prompt is None:
         prompt = tokenizer.decode(prompt_token_ids)
-    image_prompt = IMG_START + IMG_CONTEXT * (num_blocks +
-                                              1) * num_patches + IMG_END
+    image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
     new_prompt = prompt.replace('<image>', image_prompt, 1)
     new_prompt_token_ids = tokenizer.encode(new_prompt)
 
@@ -209,8 +215,19 @@
 
 
 def input_mapper_for_internvl(ctx: InputContext, data: object):
+    hf_config = ctx.get_hf_config(PretrainedConfig)
+
+    use_thumbnail = hf_config.use_thumbnail
+    min_num = hf_config.min_dynamic_patch
+    max_num = hf_config.max_dynamic_patch
+    image_size = hf_config.vision_config.image_size
+
     if isinstance(data, Image.Image):
-        data = image_to_pixel_values(data)
+        data = image_to_pixel_values(data,
+                                     image_size,
+                                     min_num,
+                                     max_num,
+                                     use_thumbnail=use_thumbnail)
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
@@ -240,10 +257,17 @@
                                         add_special_tokens=False)[0],
         image_feature_size_override=image_feature_size,
     )
+
+    image_size = vision_config.image_size
+    min_num = hf_config.min_dynamic_patch
+    max_num = hf_config.max_dynamic_patch
+    max_image_width = max_num * image_size
+    max_image_height = min_num * image_size
+
     mm_data = dummy_image_for_clip(
         vision_config,
-        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
-        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+        image_width_override=max_image_width,
+        image_height_override=max_image_height,
     )
 
     return seq_data, mm_data
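
For reference, a minimal usage sketch of the reworked helper, mirroring what the updated InternVLProcessor test shim does. The checkpoint name, the image path, and the trust_remote_code flag are illustrative assumptions; the keyword names and the config attributes (min_dynamic_patch, max_dynamic_patch, use_thumbnail) come from the diff above.

# Sketch only: drive the new image_to_pixel_values signature from an
# InternVL2 checkpoint config instead of the old hard-coded defaults.
from PIL import Image
from transformers import AutoConfig

from vllm.model_executor.models.internvl import image_to_pixel_values

# Illustrative inputs; any InternVL2 checkpoint and RGB image should do.
config = AutoConfig.from_pretrained("OpenGVLab/InternVL2-1B",
                                    trust_remote_code=True)
image = Image.open("example.jpg").convert("RGB")

pixel_values = image_to_pixel_values(
    image,
    input_size=config.vision_config.image_size,
    min_num=config.min_dynamic_patch,
    max_num=config.max_dynamic_patch,
    use_thumbnail=config.use_thumbnail,
)
# One tile per dynamic patch, plus an optional thumbnail tile.
print(pixel_values.shape)  # e.g. torch.Size([n_tiles, 3, 448, 448])

The same config-driven bounds replace the removed MAX_IMAGE_FEATURE_SIZE_* constants in the dummy-data path, where the worst case is now a single row of max_num tiles (max_num * image_size wide, min_num * image_size tall).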