from functools import lru_cache
from typing import List, Optional, Tuple

import pytest
from transformers import AutoConfig, AutoTokenizer

from vllm.model_executor.models.llava_next import (
    get_llava_next_image_feature_size)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ..conftest import IMAGE_ASSETS
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

_PREFACE = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's "
    "questions.")

# NOTE(review): these prompts contain no explicit image placeholder token,
# while vllm_to_hf_output de-duplicates runs of IMAGE_TOKEN_ID, which suggests
# one is expected in the prompt -- confirm against the upstream test.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    f"{_PREFACE} USER: \nWhat's the content of the image? ASSISTANT:",
    "cherry_blossom":
    f"{_PREFACE} USER: \nWhat is the season? ASSISTANT:",
})

IMAGE_TOKEN_ID = 32000


@lru_cache(maxsize=8)
def _get_tokenizer(model: str):
    """Load the HF tokenizer for ``model`` once and reuse it across calls."""
    return AutoTokenizer.from_pretrained(model)


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(output_ids, output_str, logprobs)`` produced by the
            vLLM runner.
        model: HF model name whose tokenizer supplies the EOS token.

    Returns:
        ``(hf_output_ids, hf_output_str, logprobs)`` where:

        * each run of consecutive ``IMAGE_TOKEN_ID`` tokens is reduced to its
          first token;
        * the leading space of the decoded text is removed;
        * the decoded EOS token is appended to the text when the ids end
          with EOS.

        ``logprobs`` is passed through unchanged.

    Raises:
        AssertionError: if the vLLM text does not start with a space.
    """
    output_ids, output_str, out_logprobs = vllm_output

    tokenizer = _get_tokenizer(model)
    eos_token_id = tokenizer.eos_token_id

    # Keep only the first token of each run of image tokens. The ``idx == 0``
    # guard stops ``output_ids[idx - 1]`` from wrapping around to the LAST
    # element, which would wrongly drop a leading image token whenever the
    # sequence also ends with one.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != IMAGE_TOKEN_ID or idx == 0
        or output_ids[idx - 1] != IMAGE_TOKEN_ID
    ]

    # Strip the leading space of vLLM's decoded text so it lines up with the
    # HF text; startswith() avoids an IndexError on empty output.
    assert output_str.startswith(" "), (
        f"Expected vLLM output text to start with a space, got {output_str!r}")
    hf_output_str = output_str[1:]
    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs


@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-vicuna-7b-hf"])
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype, max_tokens, num_logprobs) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.
    For every image asset, the image is rescaled by each factor in
    ``size_factors`` and paired with that asset's prompt from
    ``HF_IMAGE_PROMPTS``. The same prompts and rescaled PIL images are passed
    (via ``images=``) to both the HF runner and the vLLM runner.
    The vLLM output is sanitized with ``vllm_to_hf_output`` before comparing
    greedy-decoding logprobs with the HF output.
    """
    images = [asset.pil_image for asset in image_assets]

    # One (prompts, images) batch per image asset: the asset's prompt repeated
    # once per size factor, alongside the correspondingly rescaled image.
    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
                     max_model_len=4096,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(batch_prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=batch_images)
            for batch_prompts, batch_images in inputs_per_image
        ]

    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(batch_prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=batch_images)
            for batch_prompts, batch_images in inputs_per_image
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("height_and_width_and_result",
                         [(1669, 2560, 2144), (183, 488, 776)])
def test_image_feature_size(height_and_width_and_result):
    """Check ``get_llava_next_image_feature_size`` against known values.

    Each parameter is ``(input_height, input_width, expected_size)``,
    evaluated with the ``llava-hf/llava-v1.6-mistral-7b-hf`` config.
    """
    height, width, result = height_and_width_and_result
    config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
    assert get_llava_next_image_feature_size(config,
                                             input_height=height,
                                             input_width=width) == result