from typing import List, Optional, Tuple, Type, overload

import pytest
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer

from vllm.inputs import InputContext
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput,
                          VllmRunner, _ImageAssets)
from ...utils import build_model_context, check_logprobs_close

# Passed to the vLLM runner as the per-prompt image limit.
_LIMIT_IMAGE_PER_PROMPT = 4

# One prompt per image asset, keyed by asset name.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "[INST] \nWhat's the content of the image? [/INST]",
    "cherry_blossom":
    "[INST] \nWhat is the season? [/INST]",
})

models = ["llava-hf/llava-v1.6-mistral-7b-hf"]


@pytest.fixture()
def get_max_llava_next_image_tokens():
    """Return ``get_max_llava_next_image_tokens``, imported lazily."""
    from vllm.model_executor.models.llava_next import (
        get_max_llava_next_image_tokens)
    return get_max_llava_next_image_tokens


@pytest.fixture()
def dummy_data_for_llava_next():
    """Return ``dummy_data_for_llava_next``, imported lazily."""
    from vllm.model_executor.models.llava_next import dummy_data_for_llava_next
    return dummy_data_for_llava_next


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(output_ids, output_str, out_logprobs)`` as produced
            by the vLLM runner.
        model: Model name used to load the HF config and tokenizer.

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)`` where runs of
        consecutive image tokens are collapsed to a single token, the
        leading space of the text is removed, and the decoded EOS token is
        appended to the text when the last token id is EOS.

    Raises:
        AssertionError: If ``output_str`` does not start with a space.
    """
    output_ids, output_str, out_logprobs = vllm_output

    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    # Keep a token unless it is an image token directly preceded by another
    # image token. The explicit ``idx == 0`` check matters: without it,
    # ``output_ids[idx - 1]`` would wrap around to the *last* token and a
    # leading image token could be dropped by mistake.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if (token_id != image_token_id or idx == 0
            or output_ids[idx - 1] != image_token_id)
    ]

    assert output_str[0] == " "
    hf_output_str = output_str[1:]
    if hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs


@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    sizes: List[Tuple[int, int]],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Build per-image inputs and compare HF against vLLM on them.

    For every image asset, the matching prompt from ``HF_IMAGE_PROMPTS`` is
    repeated once per variant, where the variants are the image rescaled by
    each entry of ``size_factors`` or resized to each entry of ``sizes``.
    ``size_factors`` takes precedence when both are given.

    Raises:
        ValueError: If neither ``size_factors`` nor ``sizes`` is provided.
    """
    images = [asset.pil_image for asset in image_assets]

    if size_factors is not None:
        inputs_per_image = [(
            [prompt for _ in size_factors],
            [rescale_image_size(image, factor) for factor in size_factors],
        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    elif sizes is not None:
        inputs_per_image = [(
            [prompt for _ in sizes],
            [image.resize(size) for size in sizes],
        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    else:
        raise ValueError("You must provide either `size_factors` or `sizes`")

    _run_test(hf_runner,
              vllm_runner,
              inputs_per_image,
              model,
              dtype=dtype,
              max_tokens=max_tokens,
              num_logprobs=num_logprobs,
              tensor_parallel_size=tensor_parallel_size,
              distributed_executor_backend=distributed_executor_backend)


def _run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Generate greedily with vLLM and HF and check the logprobs are close.

    Args:
        inputs: A list of ``(prompts, images)`` batches; each batch is sent
            to both runners in a single generate call.

    The vLLM outputs are passed through ``vllm_to_hf_output`` before being
    compared with the HF outputs via ``check_logprobs_close``.
    """
    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
                     max_model_len=10240,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
                                          }) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images)
            for prompts, images in inputs
        ]

    with hf_runner(model, dtype=dtype,
                   auto_cls=AutoModelForVision2Seq) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images)
            for prompts, images in inputs
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype, max_tokens, num_logprobs) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are from IMAGE_ASSETS.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "sizes",
    [[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model,
                            sizes, dtype, max_tokens, num_logprobs) -> None:
    """Compare hf and vllm on images resized to fixed ``sizes``."""
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        sizes=sizes,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
                                      model, dtype, max_tokens,
                                      num_logprobs) -> None:
    """Compare hf and vllm on one batch mixing multi- and single-image input.

    The batch holds four prompts paired, in order, with two images, two
    images of different sizes, four images, and a single image.
    """
    stop_sign = image_assets[0].pil_image
    cherry_blossom = image_assets[1].pil_image

    # NOTE(review): these prompts contain no explicit image placeholder
    # text; confirm this is what both runners expect for multi-image input.
    inputs = [(
        [
            "[INST] \nDescribe 2 images. [/INST]",
            "[INST] \nDescribe 2 images. [/INST]",
            "[INST] \nDescribe 4 images. [/INST]",
            "[INST] \nWhat is the season? [/INST]"
        ],
        [
            [stop_sign, cherry_blossom],
            # Images with different sizes and aspect-ratios
            [
                rescale_image_size(stop_sign, 0.1),
                stop_sign,
            ],
            [
                stop_sign,
                rescale_image_size(stop_sign, 0.25),
                cherry_blossom.resize((183, 488)),
                cherry_blossom.resize((488, 183))
            ],
            cherry_blossom,
        ])]

    _run_test(
        hf_runner,
        vllm_runner,
        inputs,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("gridpoints,expected_max_tokens", [
    ([[336, 336]], 1176),
    ([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]], 2928),
])
def test_get_max_llava_next_image_tokens(gridpoints, expected_max_tokens,
                                         get_max_llava_next_image_tokens):
    """Check the max image token count for the given grid pinpoints."""
    ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")

    # Update the config image_grid_pinpoints
    # and calculate the resulting max tokens
    ctx.model_config.hf_config.image_grid_pinpoints = gridpoints

    actual_max_tokens = get_max_llava_next_image_tokens(
        InputContext(ctx.model_config))

    assert expected_max_tokens == actual_max_tokens


@pytest.mark.parametrize(
    "gridpoints,expected_size",
    [
        # One point; it has to be the largest
        ([[336, 336]], (336, 336)),
        # Default for most llava next models; the 2x2 tile is the largest
        ([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
         (672, 672)),
        # If two rectangular gridpoints are the same, the more vertical
        # one has the higher feature count due to newline features
        ([[336, 672], [672, 336]], (672, 336))
    ])
def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
                                                gridpoints, expected_size):
    """Check the dummy image size chosen for the given grid pinpoints."""
    ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")

    # Update the config image_grid_pinpoints
    ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
    seq_len = 5000  # bigger than the max feature size for any image

    seq_data, mm_data = dummy_data_for_llava_next(
        ctx,
        seq_len=seq_len,
        mm_counts={"image": 1},
    )

    # The dummy data dims should match the gridpoint with the biggest feat size
    assert mm_data["image"].height == expected_size[0]
    assert mm_data["image"].width == expected_size[1]
    assert len(seq_data.get_token_ids()) >= seq_len