[Model] Support E5-V (#9576)

Cyrus Leung 2024-10-23 11:35:29 +08:00 committed by GitHub
parent 29061ed9df
commit 831540cf04
12 changed files with 528 additions and 86 deletions

View File

@ -334,6 +334,14 @@ The following modalities are supported depending on the model:
- **V**\ ideo
- **A**\ udio
Any combination of modalities joined by :code:`+` is supported.
- e.g.: :code:`T + I` means that the model supports text-only, image-only, and text-with-image inputs.
On the other hand, modalities separated by :code:`/` are mutually exclusive.
- e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs.
.. _supported_vlms:
Text Generation
@ -484,6 +492,12 @@ Multimodal Embedding
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`LlavaNextForConditionalGeneration`
- LLaVA-NeXT-based
- T / I
- :code:`royokong/e5-v`
-
- ✅︎
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision-based
- T + I
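As a quick reference for the new table row, here is a minimal offline-inference sketch for :code:`royokong/e5-v`. It is a condensed version of the embedding example script added later in this commit; the model name, prompt template, prompt wording, and image URL all come from that script.

from vllm import LLM
from vllm.multimodal.utils import fetch_image

# E5-V expects queries wrapped in the Llama-3 chat template.
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501

llm = LLM(model="royokong/e5-v", task="embedding", max_model_len=4096)

# Image-only query; a text-only query uses the same call without multi_modal_data.
prompt = llama3_template.format("<image>\nSummary above image in one word: ")
image = fetch_image(
    "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg"  # noqa: E501
)

outputs = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}})
print(len(outputs[0].outputs.embedding))  # embedding dimension

Per the :code:`T / I` entry, text and image inputs are mutually exclusive for this model, so a single request carries either a text query or an image, not both.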

View File

@ -1,6 +1,6 @@
""" """
This example shows how to use vLLM for running offline inference This example shows how to use vLLM for running offline inference with
with the correct prompt format on vision language models. the correct prompt format on vision language models for text generation.
For most models, the prompt format should follow corresponding examples For most models, the prompt format should follow corresponding examples
on HuggingFace model repository. on HuggingFace model repository.
@ -450,7 +450,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description='Demo on using vLLM for offline inference with '
'vision language models') 'vision language models for text generation')
parser.add_argument('--model-type', parser.add_argument('--model-type',
'-m', '-m',
type=str, type=str,

View File

@ -1,22 +1,170 @@
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for multimodal embedding.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
from argparse import Namespace
from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
from PIL.Image import Image
from vllm import LLM
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser
class TextQuery(TypedDict):
modality: Literal["text"]
text: str
class ImageQuery(TypedDict):
modality: Literal["image"]
image: Image
class TextImageQuery(TypedDict):
modality: Literal["text+image"]
text: str
image: Image
QueryModality = Literal["text", "image", "text+image"]
Query = Union[TextQuery, ImageQuery, TextImageQuery]
class ModelRequestData(NamedTuple):
llm: LLM
prompt: str
image: Optional[Image]
def run_e5_v(query: Query):
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501
if query["modality"] == "text":
text = query["text"]
prompt = llama3_template.format(
f"{text}\nSummary above sentence in one word: ")
image = None
elif query["modality"] == "image":
prompt = llama3_template.format(
"<image>\nSummary above image in one word: ")
image = query["image"]
else:
modality = query['modality']
raise ValueError(f"Unsupported query modality: '{modality}'")
llm = LLM(
model="royokong/e5-v",
task="embedding",
max_model_len=4096,
)
return ModelRequestData(
llm=llm,
prompt=prompt,
image=image,
)
def run_vlm2vec(query: Query):
if query["modality"] == "text":
text = query["text"]
prompt = f"Find me an everyday image that matches the given caption: {text}" # noqa: E501
image = None
elif query["modality"] == "image":
prompt = "<|image_1|> Find a day-to-day image that looks similar to the provided image." # noqa: E501
image = query["image"]
elif query["modality"] == "text+image":
text = query["text"]
prompt = f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501
image = query["image"]
else:
modality = query['modality']
raise ValueError(f"Unsupported query modality: '{modality}'")
llm = LLM(
model="TIGER-Lab/VLM2Vec-Full",
task="embedding",
trust_remote_code=True,
mm_processor_kwargs={"num_crops": 4},
)
return ModelRequestData(
llm=llm,
prompt=prompt,
image=image,
)
def get_query(modality: QueryModality):
if modality == "text":
return TextQuery(modality="text", text="A dog sitting in the grass")
if modality == "image":
return ImageQuery(
modality="image",
image=fetch_image(
"https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg" # noqa: E501
),
)
if modality == "text+image":
return TextImageQuery(
modality="text+image",
text="A cat standing in the snow.",
image=fetch_image(
"https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg" # noqa: E501
),
)
msg = f"Modality {modality} is not supported."
raise ValueError(msg)
def run_encode(model: str, modality: QueryModality):
query = get_query(modality)
req_data = model_example_map[model](query)
mm_data = {}
if req_data.image is not None:
mm_data["image"] = req_data.image
outputs = req_data.llm.encode({
"prompt": req_data.prompt,
"multi_modal_data": mm_data,
})
for output in outputs:
print(output.outputs.embedding)
def main(args: Namespace):
run_encode(args.model_name, args.modality)
model_example_map = {
"e5_v": run_e5_v,
"vlm2vec": run_vlm2vec,
}
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models for multimodal embedding')
parser.add_argument('--model-name',
'-m',
type=str,
default="vlm2vec",
choices=model_example_map.keys(),
help='The name of the embedding model.')
parser.add_argument('--modality',
type=str,
default="image",
choices=get_args(QueryModality),
help='Modality of the input.')
args = parser.parse_args()
main(args)
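The script above only prints raw embedding vectors. A typical next step is to score a text query against an image with cosine similarity, which is what E5-V is trained for. Below is a small self-contained helper, not part of the script, that accepts the lists returned by output.outputs.embedding; the vectors at the bottom are made up for illustration.

from typing import List

import torch
import torch.nn.functional as F


def cosine_score(emb_a: List[float], emb_b: List[float]) -> float:
    """Score two embeddings returned by llm.encode()."""
    a = torch.tensor(emb_a)
    b = torch.tensor(emb_b)
    # The pooler added in this commit L2-normalizes E5-V embeddings, so the dot
    # product already equals cosine similarity; F.cosine_similarity is used
    # here for readability.
    return F.cosine_similarity(a, b, dim=0).item()


print(cosine_score([1.0, 0.0, 0.0], [0.8, 0.6, 0.0]))  # 0.8

For example, encoding "A dog sitting in the grass" as a text query and the dog photo from get_query("image") as an image query, then passing the two resulting embeddings to cosine_score, gives the similarity used for retrieval.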

View File

@ -1,7 +1,7 @@
""" """
This example shows how to use vLLM for running offline inference with This example shows how to use vLLM for running offline inference with
multi-image input on vision language models, using the chat template defined multi-image input on vision language models for text generation,
by the model. using the chat template defined by the model.
""" """
from argparse import Namespace from argparse import Namespace
from typing import List, NamedTuple, Optional from typing import List, NamedTuple, Optional
@ -334,7 +334,8 @@ def main(args: Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description='Demo on using vLLM for offline inference with '
'vision language models that support multi-image input') 'vision language models that support multi-image input for text '
'generation')
parser.add_argument('--model-type', parser.add_argument('--model-type',
'-m', '-m',
type=str, type=str,

View File

@ -43,10 +43,12 @@ _TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
_M = TypeVar("_M")
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
PromptImageInput = _PromptMultiModalInput[Image.Image]
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
def _read_prompts(filename: str) -> List[str]:
@ -318,12 +320,12 @@ class HfRunner:
"text": prompt, "text": prompt,
"return_tensors": "pt", "return_tensors": "pt",
} }
if images is not None and images[i] is not None: if images is not None and (image := images[i]) is not None:
processor_kwargs["images"] = images[i] processor_kwargs["images"] = image
if videos is not None and videos[i] is not None: if videos is not None and (video := videos[i]) is not None:
processor_kwargs["videos"] = videos[i] processor_kwargs["videos"] = video
if audios is not None and audios[i] is not None: if audios is not None and (audio_tuple := audios[i]) is not None:
audio, sr = audios[i] audio, sr = audio_tuple
processor_kwargs["audio"] = audio processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr processor_kwargs["sampling_rate"] = sr
@ -338,7 +340,7 @@ class HfRunner:
self,
prompts: List[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
@ -368,7 +370,7 @@ class HfRunner:
prompts: List[str],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
@ -409,7 +411,7 @@ class HfRunner:
prompts: List[str],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[List[torch.Tensor]]:
@ -488,7 +490,7 @@ class HfRunner:
num_logprobs: int,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
**kwargs: Any,
) -> List[TokensTextLogprobs]:
all_inputs = self.get_inputs(prompts,
@ -657,15 +659,18 @@ class VllmRunner:
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
if image is not None:
inputs[i]["multi_modal_data"] = {"image": image}
if videos is not None:
for i, video in enumerate(videos):
if video is not None:
inputs[i]["multi_modal_data"] = {"video": video}
if audios is not None:
for i, audio in enumerate(audios):
if audio is not None:
inputs[i]["multi_modal_data"] = {"audio": audio}
return inputs
@ -837,13 +842,20 @@ class VllmRunner:
returned_outputs.append((token_ids, texts))
return returned_outputs
def encode(
self,
prompts: List[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[List[float]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.encode(inputs)
return [req_output.outputs.embedding for req_output in req_outputs]
def __enter__(self):
return self
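For reference, the extended encode signature is used like this by the new embedding tests: prompts and per-prompt multimodal inputs are passed as parallel lists, and a None entry means the prompt is text-only. This is only a sketch; vllm_runner is the pytest fixture built on VllmRunner in this conftest, and the prompts and blank PIL image are placeholders rather than real test data.

from typing import List, Optional

from PIL import Image


def _example_encode(vllm_runner):
    prompts: List[str] = [
        "<image>\nSummary above image in one word: ",  # image-grounded prompt
        "cherry blossom\nSummary above sentence in one word: ",  # text-only
    ]
    # Aligned with prompts; None marks entries without an image.
    images: List[Optional[Image.Image]] = [Image.new("RGB", (336, 336)), None]

    with vllm_runner("royokong/e5-v", task="embedding",
                     max_model_len=4096, enforce_eager=True) as vllm_model:
        embeddings = vllm_model.encode(prompts, images=images)

    assert len(embeddings) == len(prompts)  # one vector per prompt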

View File

@ -16,7 +16,8 @@ def check_embeddings_close(
for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
zip(embeddings_0_lst, embeddings_1_lst)):
assert len(embeddings_0) == len(embeddings_1), (
f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")
sim = F.cosine_similarity(torch.tensor(embeddings_0),
torch.tensor(embeddings_1),
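A toy pair of calls illustrating the behaviour, including the new failure message for mismatched embedding lengths. The vectors below are made up, and the relative import mirrors how the embedding tests in this commit pull in the helper.

import pytest

from ..utils import check_embeddings_close


def test_toy_close_embeddings():
    # Nearly identical unit vectors pass the cosine-similarity check.
    check_embeddings_close(
        embeddings_0_lst=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        embeddings_1_lst=[[0.999, 0.001, 0.0], [0.0, 1.0, 0.0]],
        name_0="hf",
        name_1="vllm",
    )


def test_toy_length_mismatch():
    with pytest.raises(AssertionError, match="Length mismatch"):
        check_embeddings_close(
            embeddings_0_lst=[[1.0, 0.0, 0.0]],
            embeddings_1_lst=[[1.0, 0.0]],
            name_0="hf",
            name_1="vllm",
        )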

View File

@ -0,0 +1,135 @@
from typing import List, Type
import pytest
import torch.nn.functional as F
from transformers import AutoModelForVision2Seq
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ..utils import check_embeddings_close
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501
HF_TEXT_PROMPTS = [
# T -> X
llama3_template.format(
"The label of the object is stop sign\nSummary above sentence in one word: " # noqa: E501
),
# T -> X
llama3_template.format(
"cherry blossom\nSummary above sentence in one word: "),
]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
# I -> X
"stop_sign":
llama3_template.format("<image>\nSummary above image in one word: "),
# I -> X
"cherry_blossom":
llama3_template.format("<image>\nSummary above image in one word: "),
})
MODELS = ["royokong/e5-v"]
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
input_texts: List[str],
input_images: PromptImageInput,
model: str,
*,
dtype: str,
) -> None:
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model,
task="embedding",
dtype=dtype,
max_model_len=4096,
enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.encode(input_texts, images=input_images)
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForVision2Seq) as hf_model:
# Patch the issue where image_token_id
# exceeds the maximum allowed vocab size
hf_model.model.resize_token_embeddings(
hf_model.model.language_model.vocab_size + 1)
all_inputs = hf_model.get_inputs(input_texts, images=input_images)
all_outputs = []
for inputs in all_inputs:
# Based on: https://huggingface.co/royokong/e5-v
outputs = hf_model.model(
**hf_model.wrap_device(inputs,
device=hf_model.model.device.type),
return_dict=True,
output_hidden_states=True,
)
pooled_output = F.normalize(outputs.hidden_states[-1][0, -1, :],
dim=-1)
all_outputs.append(pooled_output.tolist())
hf_outputs = all_outputs
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_text(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images, # type: ignore
model,
dtype=dtype,
)
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_image(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [
(text, asset.pil_image)
for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images,
model,
dtype=dtype,
)

View File

@ -1,42 +1,53 @@
from typing import List, Type
import pytest
import torch.nn.functional as F
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ..utils import check_embeddings_close
HF_TEXT_PROMPTS = [
# T -> X
"Find me an everyday image that matches the given caption: The label of the object is stop sign", # noqa: E501
# T -> X
"Retrieve an image of this caption: cherry blossom",
]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
# T + I -> X
"stop_sign":
"<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign", # noqa: E501
# I -> X
"cherry_blossom":
"<|image_1|> Represent the given image for classification", # noqa: E501
})
MODELS = ["TIGER-Lab/VLM2Vec-Full"]
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
input_texts: List[str],
input_images: PromptImageInput,
model: str,
*,
dtype: str,
) -> None:
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model, task="embedding", dtype=dtype,
enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.encode(input_texts, images=input_images)
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
hf_model_kwargs = {"_attn_implementation": "eager"}
with hf_runner(model, dtype=dtype,
model_kwargs=hf_model_kwargs) as hf_model:
all_inputs = hf_model.get_inputs(input_texts, images=input_images)
all_outputs = []
for inputs in all_inputs:
@ -61,3 +72,53 @@ def test_models(
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_text(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images, # type: ignore
model,
dtype=dtype,
)
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_image(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [
(text, asset.pil_image)
for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images,
model,
dtype=dtype,
)

View File

@ -13,11 +13,13 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
@ -28,8 +30,8 @@ from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
@ -312,6 +314,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@ -605,14 +611,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = embed_multimodal(
input_ids,
self.config.image_token_index,
self.language_model.model.get_input_embeddings,
lambda _: self._process_image_input(image_input),
)
input_ids = None
else:
inputs_embeds = None
@ -641,6 +645,13 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
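The pooler registered above implements the E5-V recipe: take the hidden state of the final token of each sequence and L2-normalize it, matching the HF reference in the new test (F.normalize(outputs.hidden_states[-1][0, -1, :], dim=-1)). Below is a standalone sketch of that operation on made-up shapes, not vLLM's actual Pooler class.

import torch
import torch.nn.functional as F


def last_token_pool(hidden_states: torch.Tensor,
                    seq_lens: torch.Tensor) -> torch.Tensor:
    """hidden_states: [total_tokens, hidden_size] for a packed batch;
    seq_lens: number of tokens in each sequence."""
    # Position of the final token of every sequence in the packed layout.
    last_indices = torch.cumsum(seq_lens, dim=0) - 1
    pooled = hidden_states[last_indices]     # [num_seqs, hidden_size]
    return F.normalize(pooled, p=2, dim=-1)  # unit-norm embeddings


hidden = torch.randn(10, 4096)  # two sequences of 6 and 4 tokens, packed
print(last_token_pool(hidden, torch.tensor([6, 4])).shape)  # torch.Size([2, 4096])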

View File

@ -467,8 +467,6 @@ def input_processor_for_phi3v(ctx: InputContext,
prompt_token_ids = inputs["prompt_token_ids"].copy()
print("prompt_token_ids (old)", prompt_token_ids)
# masked placeholder with image token id
for idx in image_idx:
candidates = _get_image_placeholder_token_id_candidates(model_config,

View File

@ -94,6 +94,7 @@ _EMBEDDING_MODELS = {
"MistralModel": ("llama", "LlamaEmbeddingModel"), "MistralModel": ("llama", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
# [Multimodal] # [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
} }

View File

@ -1,7 +1,7 @@
import itertools
from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Protocol, Tuple, Union, overload)
import torch
import torch.nn as nn
@ -294,10 +294,11 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
_embedding_count_expression(inner) for inner in embeddings)
def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor,
multimodal_embeddings: NestedTensors,
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
@ -306,8 +307,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
Note:
This updates ``inputs_embeds`` in place.
"""
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
flattened = _flatten_embeddings(multimodal_embeddings)
@ -317,10 +317,70 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
f"Attempted to assign {expr} = {flattened.shape[0]} " f"Attempted to assign {expr} = {flattened.shape[0]} "
f"multimodal tokens to {num_expected_tokens} placeholders") f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = flattened inputs_embeds[is_multimodal] = flattened
return inputs_embeds return inputs_embeds
def embed_multimodal(
input_ids: torch.Tensor,
multimodal_token_id: int,
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
get_multimodal_embeds: Callable[[torch.Tensor], Union[torch.Tensor,
List[torch.Tensor]]],
) -> torch.Tensor:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
``multimodal_token_id`` is used to determine whether a token ID should
be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
Compared to ``merge_multimodal_embeddings``, this avoids running
``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
which causes issues when the placeholder token ID exceeds the
vocabulary size of the language model.
"""
is_multimodal = input_ids == multimodal_token_id
is_text = ~is_multimodal
text_embeds = get_text_embeds(input_ids[is_text])
multimodal_embeds = get_multimodal_embeds(input_ids[is_multimodal])
merged_embeds = torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
)
merged_embeds[is_text] = text_embeds
return _merge_multimodal_embeddings(
merged_embeds,
is_multimodal,
multimodal_embeds,
)
def merge_multimodal_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors,
placeholder_token_id: int,
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``.
Note:
This updates ``inputs_embeds`` in place.
"""
return _merge_multimodal_embeddings(
inputs_embeds,
(input_ids == placeholder_token_id),
multimodal_embeddings,
)
class LayerFn(Protocol):
def __call__(self, prefix: str) -> torch.nn.Module:
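To make the motivation in the embed_multimodal docstring concrete, here is a toy, self-contained repro of the merging pattern. It re-implements the control flow on small tensors instead of importing the helper; the sizes and the placeholder ID 100 are arbitrary, with 100 deliberately outside the toy text vocabulary.

import torch


def embed_multimodal_toy(input_ids, multimodal_token_id,
                         get_text_embeds, get_multimodal_embeds):
    # Same control flow as embed_multimodal above, on toy tensors.
    is_multimodal = input_ids == multimodal_token_id
    is_text = ~is_multimodal

    # Only real text token IDs reach the embedding table, so an
    # out-of-vocab placeholder ID can never trigger an index error.
    text_embeds = get_text_embeds(input_ids[is_text])
    mm_embeds = get_multimodal_embeds(input_ids[is_multimodal])

    merged = torch.empty(input_ids.shape[0], text_embeds.shape[1],
                         dtype=text_embeds.dtype)
    merged[is_text] = text_embeds
    merged[is_multimodal] = mm_embeds
    return merged


vocab = torch.nn.Embedding(10, 4)  # toy text vocabulary of size 10
image_token_id = 100               # larger than the vocabulary on purpose
input_ids = torch.tensor([1, 2, 100, 100, 3])

with torch.inference_mode():  # no autograd needed for this illustration
    merged = embed_multimodal_toy(
        input_ids,
        image_token_id,
        vocab,                                    # embeds the 3 text tokens
        lambda ids: torch.ones(ids.shape[0], 4),  # stand-in for vision features
    )
print(merged.shape)  # torch.Size([5, 4])

Calling vocab(input_ids) directly would raise an index error on the two placeholder positions, which is exactly the failure mode the helper above avoids for models like E5-V whose image token ID lies past the language model's vocabulary.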