[Model][VLM] Support multi-images inputs for Phi-3-vision models (#7783)

Isotr0py 2024-08-25 19:51:20 +08:00 committed by GitHub
parent 80162c44b1
commit 8aaf3d5347
2 changed files with 168 additions and 29 deletions
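For context, this change lets a single Phi-3-vision prompt reference several images via 1-indexed <|image_N|> placeholders. Below is a minimal offline-inference sketch of the intended usage; it is not part of the diff, the image files and sampling values are placeholders, and it assumes the limit_mm_per_prompt engine option and list-valued multi_modal_data available in this vLLM release.

from PIL import Image

from vllm import LLM, SamplingParams

# Allow up to two images per prompt (assumed engine option; the new test
# below passes the same option through vllm_runner).
llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    max_model_len=4096,
    limit_mm_per_prompt={"image": 2},
)

prompt = ("<|user|>\n<|image_1|>\n<|image_2|>\n"
          "Describe these images.<|end|>\n<|assistant|>\n")
images = [Image.open("a.jpg"), Image.open("b.jpg")]  # placeholder image files

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": images},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)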

View File

@@ -21,6 +21,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "cherry_blossom":
     "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
 })
+HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n"  # noqa: E501
 
 models = ["microsoft/Phi-3.5-vision-instruct"]
@@ -184,3 +185,113 @@ def test_regression_7840(hf_runner, vllm_runner, image_assets, model,
         num_logprobs=10,
         tensor_parallel_size=1,
     )
+
+
+def run_multi_image_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    images: List[Image.Image],
+    model: str,
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test are under tests/images.
+    For the huggingface runner, we provide the PIL images as input.
+    For the vllm runner, we provide MultiModalDataDict objects
+    and the corresponding MultiModalConfig as input.
+    Note that the text input is also adjusted to abide by the vllm contract.
+    The text output is sanitized so that it can be compared with hf.
+    """
+    inputs_per_case = [
+        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+         [[rescale_image_size(image, factor) for image in images]
+          for factor in size_factors])
+    ]
+
+    # NOTE: take care of the order: run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without CUDA initialization;
+    # if we run HF first, CUDA will already be initialized, which breaks
+    # the multiprocessing backend with the fork start method (the default).
+
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model,
+                     max_model_len=4096,
+                     max_num_seqs=1,
+                     limit_mm_per_prompt={"image": len(images)},
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
+        vllm_outputs_per_case = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs_per_case
+        ]
+
+    hf_model_kwargs = {"_attn_implementation": "eager"}
+    with hf_runner(model, dtype=dtype,
+                   model_kwargs=hf_model_kwargs) as hf_model:
+        eos_token_id = hf_model.processor.tokenizer.eos_token_id
+        hf_outputs_per_case = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images,
+                                                    eos_token_id=eos_token_id)
+            for prompts, images in inputs_per_case
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
+                                        vllm_outputs_per_case):
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                vllm_to_hf_output(vllm_output, model)
+                for vllm_output in vllm_outputs
+            ],
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
+                             size_factors, dtype: str, max_tokens: int,
+                             num_logprobs: int) -> None:
+    run_multi_image_test(
+        hf_runner,
+        vllm_runner,
+        [asset.pil_image for asset in image_assets],
+        model,
+        size_factors=size_factors,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
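The multi-image prompt follows the same placeholder convention as HF_MULTIIMAGE_IMAGE_PROMPT above: 1-indexed <|image_N|> tags inside the user turn. A small helper like the following (illustrative only, not part of the test file) generalizes it to N images:

def build_multi_image_prompt(num_images: int, question: str) -> str:
    # 1-indexed placeholders: <|image_1|>, <|image_2|>, ...
    placeholders = "\n".join(f"<|image_{i}|>" for i in range(1, num_images + 1))
    return f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"

# build_multi_image_prompt(2, "Describe these images.") reproduces
# HF_MULTIIMAGE_IMAGE_PROMPT defined at the top of the test file.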

View File

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import re
 from functools import lru_cache
 from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
@@ -37,11 +38,11 @@ from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
 from vllm.sequence import IntermediateTensors, SamplerOutput
+from vllm.utils import is_list_of
 
-from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                   input_processor_for_clip)
+from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal
 from .utils import merge_multimodal_embeddings
@@ -400,9 +401,20 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
     image_data = multi_modal_data["image"]
     if isinstance(image_data, Image.Image):
         w, h = image_data.size
-        image_feature_size = get_phi3v_image_feature_size(hf_config,
-                                                          input_width=w,
-                                                          input_height=h)
+        image_feature_size = [
+            get_phi3v_image_feature_size(hf_config,
+                                         input_width=w,
+                                         input_height=h)
+        ]
+        image_data = [image_data]
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = []
+        for image in image_data:
+            w, h = image.size
+            image_feature_size.append(
+                get_phi3v_image_feature_size(hf_config,
+                                             input_width=w,
+                                             input_height=h))
     elif isinstance(image_data, torch.Tensor):
         image_feature_size = image_data.shape[0]
     else:
@@ -410,45 +422,61 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
     prompt = llm_inputs.get("prompt")
     if prompt is None:
+        image_idx = []
         new_prompt = None
     else:
+        image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
         if prompt.count("<|image|>") > 0:
             logger.warning("Please follow the prompt format that is "
                            "documented on HuggingFace which does not involve "
                            "repeating <|image|> tokens.")
-        elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
-            logger.warning("Multiple image input is not supported yet, "
-                           "so any extra image tokens will be treated "
-                           "as plain text.")
+        elif (num_image_tags := len(image_idx)) > 1:
+            assert num_image_tags == len(image_data), (
+                "The number of image placeholders does not match the "
+                "number of images.")
 
         new_prompt = prompt
 
-    prompt_token_ids = llm_inputs["prompt_token_ids"]
-    image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)
-
-    new_token_ids: List[int] = []
-    for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
-        if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
-            new_token_ids.append(_IMAGE_TOKEN_ID)
-
-            # No need to further scan the list since we only replace once
-            new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
-            break
-        else:
-            new_token_ids.append(prompt_token_ids[i])
+    prompt_token_ids = llm_inputs["prompt_token_ids"].copy()
+
+    # mask each image placeholder with the image token id
+    for idx in image_idx:
+        image_token_ids = _get_image_placeholder_token_ids(model_config,
+                                                           idx=idx)
+        for i in range(len(prompt_token_ids) - len(image_token_ids) + 1):
+            if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids:
+                prompt_token_ids[i:i + len(image_token_ids)] = [
+                    _IMAGE_TOKEN_ID
+                ] * len(image_token_ids)
+                break
+
+    # merge consecutive tag ids
+    merged_token_ids: List[int] = []
+    for is_placeholder, token_ids in itertools.groupby(
+            prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
+        if is_placeholder:
+            merged_token_ids.append(_IMAGE_TOKEN_ID)
+        else:
+            merged_token_ids.extend(list(token_ids))
+
+    # TODO: Move this to utils or integrate with clip.
+    new_token_ids: List[int] = []
+    placeholder_idx = 0
+    while merged_token_ids:
+        token_id = merged_token_ids.pop(0)
+        if token_id == _IMAGE_TOKEN_ID:
+            new_token_ids.extend(
+                repeat_and_pad_token(
+                    _IMAGE_TOKEN_ID,
+                    repeat_count=image_feature_size[placeholder_idx],
+                ))
+            placeholder_idx += 1
+        else:
+            new_token_ids.append(token_id)
 
     # NOTE: Create a defensive copy of the original inputs
     llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
                            prompt=new_prompt,
                            multi_modal_data=multi_modal_data)
 
-    return input_processor_for_clip(
-        model_config,
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        llm_inputs,
-        image_token_id=_IMAGE_TOKEN_ID,
-        image_feature_size_override=image_feature_size,
-    )
+    return llm_inputs
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
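The rewritten input_processor_for_phi3v above works in three passes: it masks every <|image_N|> placeholder span with the image token id, merges each run of consecutive image tokens into a single token, and then expands each one to that image's feature size. A self-contained toy version of the same flow, for illustration only (the token id value, helper name, and all ids below are made up, and plain list repetition stands in for repeat_and_pad_token):

import itertools
from typing import List

_IMAGE_TOKEN_ID = 32044  # illustrative value only


def expand_image_placeholders(prompt_token_ids: List[int],
                              placeholder_token_ids: List[List[int]],
                              image_feature_size: List[int]) -> List[int]:
    """Toy re-implementation of the mask -> merge -> expand steps above."""
    token_ids = prompt_token_ids.copy()

    # 1) Mask each placeholder's token span with the image token id.
    for image_token_ids in placeholder_token_ids:
        n = len(image_token_ids)
        for i in range(len(token_ids) - n + 1):
            if token_ids[i:i + n] == image_token_ids:
                token_ids[i:i + n] = [_IMAGE_TOKEN_ID] * n
                break

    # 2) Merge each run of consecutive image tokens into a single token.
    merged: List[int] = []
    for is_placeholder, group in itertools.groupby(
            token_ids, lambda t: t == _IMAGE_TOKEN_ID):
        merged.extend([_IMAGE_TOKEN_ID] if is_placeholder else list(group))

    # 3) Expand every image token to that image's feature size.
    new_token_ids: List[int] = []
    image_idx = 0
    for t in merged:
        if t == _IMAGE_TOKEN_ID:
            new_token_ids.extend([_IMAGE_TOKEN_ID] *
                                 image_feature_size[image_idx])
            image_idx += 1
        else:
            new_token_ids.append(t)
    return new_token_ids


# Two placeholders tokenizing to [1, 2] and [1, 3], separated by token 5,
# with per-image feature sizes 3 and 2:
print(expand_image_placeholders([9, 1, 2, 5, 1, 3, 7],
                                [[1, 2], [1, 3]],
                                [3, 2]))
# -> [9, 32044, 32044, 32044, 5, 32044, 32044, 7]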