[Model][VLM] Support multi-images inputs for Phi-3-vision models (#7783)
commit 8aaf3d5347
parent 80162c44b1
@@ -21,6 +21,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "cherry_blossom":
     "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
 })
+HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n"  # noqa: E501
 
 models = ["microsoft/Phi-3.5-vision-instruct"]
 
@@ -184,3 +185,113 @@ def test_regression_7840(hf_runner, vllm_runner, image_assets, model,
         num_logprobs=10,
         tensor_parallel_size=1,
     )
+
+
+def run_multi_image_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    images: List[Image.Image],
+    model: str,
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test is under tests/images.
+    For huggingface runner, we provide the PIL images as input.
+    For vllm runner, we provide MultiModalDataDict objects
+    and corresponding MultiModalConfig as input.
+    Note, the text input is also adjusted to abide by vllm contract.
+    The text output is sanitized to be able to compare with hf.
+    """
+    inputs_per_case = [
+        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+         [[rescale_image_size(image, factor) for image in images]
+          for factor in size_factors])
+    ]
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model,
+                     max_model_len=4096,
+                     max_num_seqs=1,
+                     limit_mm_per_prompt={"image": len(images)},
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
+        vllm_outputs_per_case = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs_per_case
+        ]
+
+    hf_model_kwargs = {"_attn_implementation": "eager"}
+    with hf_runner(model, dtype=dtype,
+                   model_kwargs=hf_model_kwargs) as hf_model:
+        eos_token_id = hf_model.processor.tokenizer.eos_token_id
+        hf_outputs_per_case = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images,
+                                                    eos_token_id=eos_token_id)
+            for prompts, images in inputs_per_case
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
+                                        vllm_outputs_per_case):
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                vllm_to_hf_output(vllm_output, model)
+                for vllm_output in vllm_outputs
+            ],
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
+                             size_factors, dtype: str, max_tokens: int,
+                             num_logprobs: int) -> None:
+    run_multi_image_test(
+        hf_runner,
+        vllm_runner,
+        [asset.pil_image for asset in image_assets],
+        model,
+        size_factors=size_factors,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
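For context, the behaviour exercised by the new test corresponds roughly to the offline-inference sketch below. It is a minimal illustration and not part of this commit: the image file paths and generation settings are made up, and it assumes vLLM's standard `LLM.generate` multi-modal dict interface together with the `limit_mm_per_prompt` option that the test above passes through `vllm_runner`.

```python
# Minimal sketch (not part of this diff): multi-image inference with
# Phi-3.5-vision through vLLM's offline API. Paths and generation settings
# are illustrative only.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,            # Phi-3-vision ships a custom processor
    max_model_len=4096,
    limit_mm_per_prompt={"image": 2},  # allow up to two images per prompt
)

prompt = ("<|user|>\n<|image_1|>\n<|image_2|>\n"
          "Describe these images.<|end|>\n<|assistant|>\n")
images = [
    Image.open("cherry_blossom.jpg"),  # hypothetical local copies of the
    Image.open("stop_sign.jpg"),       # test image assets
]

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": images},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```

The `limit_mm_per_prompt={"image": len(images)}` argument in the test plays the same role: without raising the per-prompt image limit, a two-image prompt would be rejected.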
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import re
 from functools import lru_cache
 from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
@@ -37,11 +38,11 @@ from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
 from vllm.sequence import IntermediateTensors, SamplerOutput
+from vllm.utils import is_list_of
 
-from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                   input_processor_for_clip)
+from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal
 from .utils import merge_multimodal_embeddings
 
@@ -400,9 +401,20 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
     image_data = multi_modal_data["image"]
     if isinstance(image_data, Image.Image):
         w, h = image_data.size
-        image_feature_size = get_phi3v_image_feature_size(hf_config,
-                                                          input_width=w,
-                                                          input_height=h)
+        image_feature_size = [
+            get_phi3v_image_feature_size(hf_config,
+                                         input_width=w,
+                                         input_height=h)
+        ]
+        image_data = [image_data]
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = []
+        for image in image_data:
+            w, h = image.size
+            image_feature_size.append(
+                get_phi3v_image_feature_size(hf_config,
+                                             input_width=w,
+                                             input_height=h))
     elif isinstance(image_data, torch.Tensor):
         image_feature_size = image_data.shape[0]
     else:
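The new branches above normalize the single-image case to the list form so that `image_feature_size` always lines up index-for-index with `image_data`. A standalone toy restatement of that normalization is sketched below; `fake_feature_size` is a hypothetical stand-in for `get_phi3v_image_feature_size`, whose real value depends on the HF image-processor config rather than this formula.

```python
# Toy restatement (not from this diff) of the image-input normalization above.
from typing import List, Union

import torch
from PIL import Image


def fake_feature_size(width: int, height: int) -> int:
    # Hypothetical placeholder; the real count depends on num_crops and the
    # Phi-3-vision image processor, not on this formula.
    return ((width // 336) + 1) * ((height // 336) + 1) * 144


def normalize(image_data: Union[Image.Image, List[Image.Image], torch.Tensor]):
    if isinstance(image_data, Image.Image):
        # Single image: wrap both the image and its feature size in lists.
        w, h = image_data.size
        return [image_data], [fake_feature_size(w, h)]
    if isinstance(image_data, list) and all(
            isinstance(img, Image.Image) for img in image_data):
        # List of images (the is_list_of branch): one feature size per image.
        return image_data, [fake_feature_size(*img.size) for img in image_data]
    if isinstance(image_data, torch.Tensor):
        # Pre-computed image embeddings: one row per feature vector.
        return image_data, image_data.shape[0]
    raise TypeError(f"Invalid image type: {type(image_data)}")


single = Image.new("RGB", (672, 336))
many = [Image.new("RGB", (336, 336)), Image.new("RGB", (1344, 672))]
print(normalize(single)[1])  # one entry
print(normalize(many)[1])    # one entry per image
```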
@@ -410,45 +422,61 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
 
     prompt = llm_inputs.get("prompt")
     if prompt is None:
+        image_idx = []
         new_prompt = None
     else:
+        image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
         if prompt.count("<|image|>") > 0:
             logger.warning("Please follow the prompt format that is "
                            "documented on HuggingFace which does not involve "
                            "repeating <|image|> tokens.")
-        elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
-            logger.warning("Multiple image input is not supported yet, "
-                           "so any extra image tokens will be treated "
-                           "as plain text.")
+        elif (num_image_tags := len(image_idx)) > 1:
+            assert num_image_tags == len(
+                image_data), "The count of image_placeholder not match image's"
 
         new_prompt = prompt
 
-    prompt_token_ids = llm_inputs["prompt_token_ids"]
-    image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)
+    prompt_token_ids = llm_inputs["prompt_token_ids"].copy()
 
-    new_token_ids: List[int] = []
-    for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
-        if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
-            new_token_ids.append(_IMAGE_TOKEN_ID)
+    # masked place_holder with image token id
+    for idx in image_idx:
+        image_token_ids = _get_image_placeholder_token_ids(model_config,
+                                                           idx=idx)
+        for i in range(len(prompt_token_ids) - len(image_token_ids) + 1):
+            if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids:
+                prompt_token_ids[i:i + len(image_token_ids)] = [
+                    _IMAGE_TOKEN_ID
+                ] * len(image_token_ids)
+                break
 
-            # No need to further scan the list since we only replace once
-            new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
-            break
-        else:
-            new_token_ids.append(prompt_token_ids[i])
+    # merge consecutive tag ids
+    merged_token_ids: List[int] = []
+    for is_placeholder, token_ids in itertools.groupby(
+            prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
+        if is_placeholder:
+            merged_token_ids.append(_IMAGE_TOKEN_ID)
+        else:
+            merged_token_ids.extend(list(token_ids))
+
+    # TODO: Move this to utils or integrate with clip.
+    new_token_ids: List[int] = []
+    placeholder_idx = 0
+    while merged_token_ids:
+        token_id = merged_token_ids.pop(0)
+        if token_id == _IMAGE_TOKEN_ID:
+            new_token_ids.extend(
+                repeat_and_pad_token(
+                    _IMAGE_TOKEN_ID,
+                    repeat_count=image_feature_size[placeholder_idx],
+                ))
+            placeholder_idx += 1
+        else:
+            new_token_ids.append(token_id)
 
     # NOTE: Create a defensive copy of the original inputs
     llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
                            prompt=new_prompt,
                            multi_modal_data=multi_modal_data)
-
-    return input_processor_for_clip(
-        model_config,
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        llm_inputs,
-        image_token_id=_IMAGE_TOKEN_ID,
-        image_feature_size_override=image_feature_size,
-    )
+    return llm_inputs
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
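To make the new placeholder handling concrete, here is a standalone walk-through of the three steps (mask, merge, expand) on toy token ids. Everything in it is made up for illustration: the ids, the pretend two-token encodings of `<|image_1|>`/`<|image_2|>`, and the plain repetition that stands in for `repeat_and_pad_token` (which in vLLM can also insert padding tokens around the repeated span).

```python
# Illustrative walk-through (not from this diff) of the mask -> merge -> expand
# steps on made-up token ids.
import itertools
from typing import List

IMAGE_TOKEN_ID = -1            # stand-in for the real _IMAGE_TOKEN_ID
IMAGE_1_TOKEN_IDS = [90, 1]    # pretend tokenization of "<|image_1|>"
IMAGE_2_TOKEN_IDS = [90, 2]    # pretend tokenization of "<|image_2|>"
image_feature_size = [3, 5]    # pretend per-image feature counts

# pretend tokens of "<s> <|image_1|> \n <|image_2|> Describe these images"
prompt_token_ids = [10, 90, 1, 99, 90, 2, 20, 21, 22]

# 1) Mask each placeholder occurrence with the image token id.
for placeholder_ids in (IMAGE_1_TOKEN_IDS, IMAGE_2_TOKEN_IDS):
    for i in range(len(prompt_token_ids) - len(placeholder_ids) + 1):
        if prompt_token_ids[i:i + len(placeholder_ids)] == placeholder_ids:
            prompt_token_ids[i:i + len(placeholder_ids)] = (
                [IMAGE_TOKEN_ID] * len(placeholder_ids))
            break

# 2) Collapse the run of image tokens from each placeholder into one token.
merged_token_ids: List[int] = []
for is_placeholder, group in itertools.groupby(
        prompt_token_ids, lambda x: x == IMAGE_TOKEN_ID):
    if is_placeholder:
        merged_token_ids.append(IMAGE_TOKEN_ID)
    else:
        merged_token_ids.extend(group)

# 3) Expand each collapsed placeholder to that image's feature length
#    (simple repetition instead of repeat_and_pad_token).
new_token_ids: List[int] = []
placeholder_idx = 0
for token_id in merged_token_ids:
    if token_id == IMAGE_TOKEN_ID:
        new_token_ids.extend(
            [IMAGE_TOKEN_ID] * image_feature_size[placeholder_idx])
        placeholder_idx += 1
    else:
        new_token_ids.append(token_id)

print(new_token_ids)
# -> [10, -1, -1, -1, 99, -1, -1, -1, -1, -1, 20, 21, 22]
```

Each placeholder ends up expanded to the feature length computed for its own image, which is what lets the per-image vision embeddings be merged back in positionally later on.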