[CI/Build][VLM] Cleanup multiple images inputs model test (#7897)
This commit is contained in:
parent
6fc4e6e07a
commit
9db642138b
@ -1,14 +1,15 @@
|
||||
from typing import List, Optional, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Type, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.types
|
||||
from PIL import Image
|
||||
from transformers import BatchEncoding
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
@ -24,6 +25,11 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"(<image>./</image>)\nWhat is the season?<|eot_id|>" \
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
})
|
||||
HF_MULTIIMAGE_IMAGE_PROMPT = \
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
|
||||
"(<image>./</image>)\n(<image>./</image>)\n" \
|
||||
"Describe these images.<|eot_id|>" \
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
models = ["openbmb/MiniCPM-Llama3-V-2_5"]
|
||||
|
||||
@ -46,13 +52,14 @@ target_dtype = "half"
|
||||
def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
inputs: List[Tuple[List[str], Union[List[Image.Image],
|
||||
List[List[Image.Image]]]]],
|
||||
model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
mm_limit: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
@ -65,12 +72,6 @@ def run_test(
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
@ -82,6 +83,7 @@ def run_test(
|
||||
max_model_len=4096,
|
||||
max_num_seqs=1,
|
||||
dtype=dtype,
|
||||
limit_mm_per_prompt={"image": mm_limit},
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
@ -93,7 +95,7 @@ def run_test(
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
stop_token_ids=stop_token_ids)
|
||||
for prompts, images in inputs_per_image
|
||||
for prompts, images in inputs
|
||||
]
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
|
||||
@ -104,7 +106,7 @@ def run_test(
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
tokenizer=tokenizer)
|
||||
for prompts, images in inputs_per_image
|
||||
for prompts, images in inputs
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
@ -138,104 +140,26 @@ def run_test(
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
inputs_per_image,
|
||||
model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=1,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
HF_MULTIIMAGE_IMAGE_PROMPT = \
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
|
||||
"(<image>./</image>)\n(<image>./</image>)\n" \
|
||||
"Describe these images.<|eot_id|>" \
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
|
||||
def run_multi_image_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
For huggingface runner, we provide the PIL images as input.
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
and corresponding MultiModalConfig as input.
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_case = [
|
||||
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
|
||||
[[rescale_image_size(image, factor) for image in images]
|
||||
for factor in size_factors])
|
||||
]
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=1,
|
||||
limit_mm_per_prompt={"image": len(images)},
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
tokenizer = vllm_model.model.get_tokenizer()
|
||||
stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
|
||||
vllm_outputs_per_case = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
stop_token_ids=stop_token_ids)
|
||||
for prompts, images in inputs_per_case
|
||||
]
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
|
||||
with hf_model, torch.no_grad():
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
tokenizer=tokenizer)
|
||||
for prompts, images in inputs_per_case
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
|
||||
vllm_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=[
|
||||
trunc_hf_output(hf_output) for hf_output in hf_outputs
|
||||
],
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
@ -256,14 +180,22 @@ def run_multi_image_test(
|
||||
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
|
||||
size_factors, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
run_multi_image_test(
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_case = [
|
||||
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
|
||||
[[rescale_image_size(image, factor) for image in images]
|
||||
for factor in size_factors])
|
||||
]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
inputs_per_case,
|
||||
model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=2,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Type, Union
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
@ -60,13 +60,14 @@ if is_hip():
|
||||
def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
images: List[Image.Image],
|
||||
inputs: List[Tuple[List[str], Union[List[Image.Image],
|
||||
List[List[Image.Image]]]]],
|
||||
model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
mm_limit: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
@ -79,13 +80,6 @@ def run_test(
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[
|
||||
rescale_image_size(image, factor, transpose=idx)
|
||||
for idx, factor in enumerate(size_factors)
|
||||
],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
@ -97,15 +91,16 @@ def run_test(
|
||||
max_model_len=4096,
|
||||
max_num_seqs=1,
|
||||
dtype=dtype,
|
||||
limit_mm_per_prompt={"image": mm_limit},
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
vllm_outputs_per_image = [
|
||||
vllm_outputs_per_case = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
for prompts, images in inputs
|
||||
]
|
||||
|
||||
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
|
||||
@ -113,17 +108,17 @@ def run_test(
|
||||
with hf_runner(model, dtype=dtype,
|
||||
model_kwargs=hf_model_kwargs) as hf_model:
|
||||
eos_token_id = hf_model.processor.tokenizer.eos_token_id
|
||||
hf_outputs_per_image = [
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
eos_token_id=eos_token_id)
|
||||
for prompts, images in inputs_per_image
|
||||
for prompts, images in inputs
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
vllm_outputs_per_image):
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
|
||||
vllm_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
@ -156,15 +151,22 @@ def run_test(
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
[asset.pil_image for asset in image_assets],
|
||||
inputs_per_image,
|
||||
model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=1,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
@ -173,97 +175,26 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
def test_regression_7840(hf_runner, vllm_runner, image_assets, model,
|
||||
dtype) -> None:
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_regresion_7840 = [
|
||||
([prompt], [image]) for image, prompt in zip(images, HF_IMAGE_PROMPTS)
|
||||
]
|
||||
|
||||
# Regression test for #7840.
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
[image_assets[0].pil_image.resize((465, 226))],
|
||||
inputs_regresion_7840,
|
||||
model,
|
||||
size_factors=[1.0],
|
||||
dtype=dtype,
|
||||
max_tokens=128,
|
||||
num_logprobs=10,
|
||||
mm_limit=1,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
def run_multi_image_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
images: List[Image.Image],
|
||||
model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
For huggingface runner, we provide the PIL images as input.
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
and corresponding MultiModalConfig as input.
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
|
||||
inputs_per_case = [
|
||||
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
|
||||
[[rescale_image_size(image, factor) for image in images]
|
||||
for factor in size_factors])
|
||||
]
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=1,
|
||||
limit_mm_per_prompt={"image": len(images)},
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
vllm_outputs_per_case = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_case
|
||||
]
|
||||
|
||||
hf_model_kwargs = {"_attn_implementation": "eager"}
|
||||
with hf_runner(model, dtype=dtype,
|
||||
model_kwargs=hf_model_kwargs) as hf_model:
|
||||
eos_token_id = hf_model.processor.tokenizer.eos_token_id
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
eos_token_id=eos_token_id)
|
||||
for prompts, images in inputs_per_case
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
|
||||
vllm_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, model)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
@ -280,18 +211,26 @@ def run_multi_image_test(
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
|
||||
size_factors, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
run_multi_image_test(
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_case = [
|
||||
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
|
||||
[[rescale_image_size(image, factor) for image in images]
|
||||
for factor in size_factors])
|
||||
]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
[asset.pil_image for asset in image_assets],
|
||||
inputs_per_case,
|
||||
model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=2,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user