[Model] Initial support for LLaVA-NeXT (#4199)
Co-authored-by: Roger Wang <ywang@roblox.com>
parent 0bfa1c4f13 · commit 6b29d6fe70
@@ -89,7 +89,11 @@ Alongside each architecture, we include some popular models that use it.
     - ✅︎
   * - :code:`LlavaForConditionalGeneration`
     - LLaVA-1.5
-    - :code:`llava-hf/llava-1.5-7b-hf`\*, :code:`llava-hf/llava-1.5-13b-hf`\*, etc.
+    - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
     -
+  * - :code:`LlavaNextForConditionalGeneration`
+    - LLaVA-NeXT
+    - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
+    -
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
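For reference, a minimal offline-loading sketch for one of the newly listed checkpoints. This is an editor's illustration, not part of the diff: it assumes the VisionLanguageConfig fields are accepted as engine keyword arguments through as_cli_args_dict(), as the vllm_runner call in the new test below suggests, so the exact flag names may differ.

from vllm import LLM
from vllm.config import VisionLanguageConfig

# Values mirror iter_llava_next_configs() in tests/models/test_llava_next.py:
# a (336, 336) pixel input corresponds to an image feature size of 1176.
vlm_config = VisionLanguageConfig(
    image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
    image_token_id=32000,
    image_input_shape=(1, 3, 336, 336),
    image_feature_size=1176,
    image_processor="llava-hf/llava-v1.6-vicuna-7b-hf",
    image_processor_revision=None,
)

# Assumption: LLM forwards these keyword arguments to the engine args, the
# same way the test passes **vlm_config.as_cli_args_dict() to vllm_runner.
llm = LLM(
    model="llava-hf/llava-v1.6-vicuna-7b-hf",
    max_model_len=4096,  # should be greater than image_feature_size
    enforce_eager=True,
    **vlm_config.as_cli_args_dict(),
)

Prompts fed to the engine need the <image> token repeated image_feature_size times, exactly as the vllm_image_prompts expansion in the test does; dynamic image token replacement is not supported yet, hence the xfail markers below.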
@@ -39,8 +39,6 @@ def iter_llava_configs(model_name: str):


model_and_vl_config = [
    *iter_llava_configs("llava-hf/llava-1.5-7b-hf"),
    # Not enough memory
    # *iter_llava_configs("llava-hf/llava-1.5-13b-hf"),
]
tests/models/test_llava_next.py (new file, 123 lines)
@@ -0,0 +1,123 @@
from typing import List, Tuple

import pytest
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig

from ..conftest import IMAGE_FILES

pytestmark = pytest.mark.llava

_PREFACE = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's "
    "questions.")

# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = [
    f"{_PREFACE} <image>\nUSER: What's the content of the image? ASSISTANT:",
    f"{_PREFACE} <image>\nUSER: What is the season? ASSISTANT:",
]

assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)


def iter_llava_next_configs(model_name: str):
    image_hw_to_feature_size = {
        (336, 336): 1176,
        (672, 672): 2928,
        (1344, 336): 1944,
        (336, 1344): 1890,
    }

    for (h, w), f in image_hw_to_feature_size.items():
        for input_type, input_shape in [
            (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
        ]:
            yield (model_name,
                   VisionLanguageConfig(image_input_type=input_type,
                                        image_feature_size=f,
                                        image_token_id=32000,
                                        image_input_shape=input_shape,
                                        image_processor=model_name,
                                        image_processor_revision=None))


model_and_vl_config = [
    *iter_llava_next_configs("llava-hf/llava-v1.6-vicuna-7b-hf"),
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
                      vlm_config: VisionLanguageConfig, model_id: str):
    """Sanitize vllm output to be comparable with hf output.
    The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
    x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
    It also reduces `output_str` from "<image><image>bla" to "bla".
    """
    input_ids, output_str = vllm_output
    image_token_id = vlm_config.image_token_id

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    image_token_str = tokenizer.decode(image_token_id)

    hf_input_ids = [
        input_id for idx, input_id in enumerate(input_ids)
        if input_id != image_token_id or input_ids[idx - 1] != image_token_id
    ]
    hf_output_str = output_str \
        .replace(image_token_str * vlm_config.image_feature_size, " ")

    return hf_input_ids, hf_output_str


@pytest.mark.xfail(
    reason="Inconsistent image processor being used due to lack "
    "of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
                model_and_config, dtype: str, max_tokens: int) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.
    For the huggingface runner, we provide the PIL images as input.
    For the vllm runner, we provide MultiModalData objects and the
    corresponding vision language config as input.
    Note that the text input is also adjusted to abide by the vLLM contract.
    The text output is sanitized to be able to compare with hf.
    """
    model_id, vlm_config = model_and_config

    with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
        hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
                                              max_tokens,
                                              images=hf_images)

    vllm_image_prompts = [
        p.replace("<image>", "<image>" * vlm_config.image_feature_size)
        for p in HF_IMAGE_PROMPTS
    ]

    with vllm_runner(
            model_id,
            dtype=dtype,
            # should be greater than image_feature_size
            max_model_len=4096,
            enforce_eager=True,
            **vlm_config.as_cli_args_dict(),
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                                  max_tokens,
                                                  images=vllm_images)

    for i in range(len(HF_IMAGE_PROMPTS)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
            vllm_outputs[i], vlm_config, model_id)
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
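To make the sanitization in vllm_to_hf_output concrete, here is a toy walk-through with hypothetical token ids (the 3148 and 1001 values are placeholders, not real model outputs):

# vLLM's prompt contains image_feature_size copies of the image token, so its
# returned ids hold a run of 32000s that HF would have collapsed to one.
input_ids = [1, 32000, 32000, 32000, 3148, 1001]  # hypothetical ids
image_token_id = 32000

hf_input_ids = [
    tok for idx, tok in enumerate(input_ids)
    if tok != image_token_id or input_ids[idx - 1] != image_token_id
]
assert hf_input_ids == [1, 32000, 3148, 1001]  # run collapsed to one token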
@@ -1,6 +1,6 @@
 import numpy as np
 import pytest
-from transformers import CLIPImageProcessor
+from transformers import CLIPImageProcessor, LlavaNextImageProcessor

 from vllm.config import ModelConfig, VisionLanguageConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -12,7 +12,7 @@ from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE
 @pytest.mark.parametrize("dtype", ["half", "float"])
 def test_clip_image_processor(hf_images, dtype):
     MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
-    IMAGE_HEIGHT = IMAGE_WIDTH = 33
+    IMAGE_HEIGHT = IMAGE_WIDTH = 560

     hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
     assert isinstance(hf_processor, CLIPImageProcessor)
@@ -55,10 +55,61 @@ def test_clip_image_processor(hf_images, dtype):
             assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"


+@pytest.mark.xfail(
+    reason="Inconsistent image processor being used due to lack "
+    "of support for dynamic image token replacement")
+@pytest.mark.parametrize("dtype", ["half", "float"])
+def test_llava_next_image_processor(hf_images, dtype):
+    MODEL_NAME = "llava-hf/llava-v1.6-34b-hf"
+    IMAGE_HEIGHT = IMAGE_WIDTH = 560
+
+    hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
+    assert isinstance(hf_processor, LlavaNextImageProcessor)
+
+    model_config = ModelConfig(
+        model=MODEL_NAME,
+        tokenizer=MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype=dtype,
+        revision=None,
+    )
+    vlm_config = VisionLanguageConfig(
+        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
+        image_token_id=64000,
+        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
+        image_feature_size=2928,
+        image_processor=MODEL_NAME,
+        image_processor_revision=None,
+    )
+
+    for image in hf_images:
+        hf_result = hf_processor.preprocess(
+            image,
+            return_tensors="pt",
+        ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
+        vllm_result = MULTIMODAL_REGISTRY.process_input(
+            ImagePixelData(image),
+            model_config=model_config,
+            vlm_config=vlm_config,
+        )
+
+        assert hf_result.keys() == vllm_result.keys()
+        for key, hf_tensor in hf_result.items():
+            hf_arr: np.ndarray = hf_tensor.numpy()
+            vllm_arr: np.ndarray = vllm_result[key].numpy()
+
+            assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
+            assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
+
+
+@pytest.mark.xfail(
+    reason="Example image pixels were not processed using HuggingFace")
 @pytest.mark.parametrize("dtype", ["float"])
 def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
     MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
-    IMAGE_HEIGHT = IMAGE_WIDTH = 33
+    IMAGE_HEIGHT = IMAGE_WIDTH = 560

     model_config = ModelConfig(
         model=MODEL_NAME,
@@ -95,7 +146,4 @@ def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
         tensor_arr: np.ndarray = tensor_result[key].numpy()

         assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}"
-
-        # The examples in PR#3042 have slightly different preprocessing from
-        # HuggingFace's LlavaProcessor, causing the test to fail.
-        # assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"
+        assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"
@@ -33,6 +33,8 @@ _GENERATION_MODELS = {
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     "LlavaForConditionalGeneration":
     ("llava", "LlavaForConditionalGeneration"),
+    "LlavaNextForConditionalGeneration":
+    ("llava_next", "LlavaNextForConditionalGeneration"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
@@ -1,7 +1,7 @@
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

 import torch
-from torch import nn
+import torch.nn as nn
 # TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
 # transformers' impl.
 from transformers import CLIPVisionModel, LlavaConfig
@@ -51,10 +51,10 @@ class LlavaMultiModalProjector(nn.Module):
         return hidden_states


-def _merge_vision_embeddings(input_ids: torch.Tensor,
-                             inputs_embeds: torch.Tensor,
-                             vision_embeddings: torch.Tensor,
-                             image_token_id: int) -> torch.Tensor:
+def merge_vision_embeddings(input_ids: torch.Tensor,
+                            inputs_embeds: torch.Tensor,
+                            vision_embeddings: torch.Tensor,
+                            image_token_id: int) -> torch.Tensor:
     """In place merges in vision_embeddings with inputs_embeds."""
     mask = (input_ids == image_token_id)

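For context, a minimal standalone sketch of what the renamed merge_vision_embeddings helper does. The real body (outside this hunk) also validates that the number of image tokens matches the number of vision embeddings, so treat this as an approximation rather than the exact implementation:

import torch


def merge_vision_embeddings_sketch(input_ids: torch.Tensor,
                                   inputs_embeds: torch.Tensor,
                                   vision_embeddings: torch.Tensor,
                                   image_token_id: int) -> torch.Tensor:
    # Positions holding the image token receive the vision embeddings in
    # order; every other position keeps its text embedding.
    mask = (input_ids == image_token_id)
    inputs_embeds[mask] = vision_embeddings.view(-1, inputs_embeds.shape[-1])
    return inputs_embeds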
@@ -151,7 +151,8 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
             return None

         if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError("Incorrect type of pixel values")
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")

         return LlavaImagePixelInputs(
             type="pixel_values",
@@ -166,7 +167,8 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
             return None

         if not isinstance(image_features, torch.Tensor):
-            raise ValueError("Incorrect type of image features")
+            raise ValueError("Incorrect type of image features. "
+                             f"Got type: {type(image_features)}")

         return LlavaImageFeatureInputs(
             type="image_features",
@@ -268,7 +270,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
         vision_embeddings = self._process_image_input(image_input)
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)

-        inputs_embeds = _merge_vision_embeddings(
+        inputs_embeds = merge_vision_embeddings(
             input_ids, inputs_embeds, vision_embeddings,
             self.vision_language_config.image_token_id)

vllm/model_executor/models/llava_next.py (new file, 445 lines)
@@ -0,0 +1,445 @@
from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
                    Union)

import torch
import torch.nn as nn
from PIL import Image
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
from vllm.sequence import SamplerOutput, SequenceData

from .llava import LlavaMultiModalProjector, merge_vision_embeddings
from .vlm_base import VisionLanguageModelBase

logger = init_logger(__name__)

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""

    image_sizes: NotRequired[torch.Tensor]
    """Shape: (batch_size, 2)"""


class LlavaNextImageFeatureInputs(TypedDict):
    type: Literal["image_features"]
    data: torch.Tensor
    """Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)"""

    image_sizes: NotRequired[torch.Tensor]
    """Shape: (batch_size, 2)"""


LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageFeatureInputs]


def _get_dummy_image_data(
    seq_len: int,
    model_config: ModelConfig,
    vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
    seq_data, fake_mm_data = get_dummy_image_data(seq_len, model_config,
                                                  vlm_config)

    config_input_type = vlm_config.image_input_type
    ImageInputType = VisionLanguageConfig.ImageInputType

    if config_input_type == ImageInputType.PIXEL_VALUES:
        _, c, h, w = vlm_config.image_input_shape
        mode = {1: "L", 3: "RGB"}[c]
        fake_mm_data = ImagePixelData(Image.new(mode, (w, h), color=0))

    return seq_data, fake_mm_data


def _image_pixel_processor(
    data: ImagePixelData,
    model_config: ModelConfig,
    vlm_config: VisionLanguageConfig,
) -> Dict[str, torch.Tensor]:
    image = data.image

    if isinstance(image, torch.Tensor):
        pixel_values = image.to(model_config.dtype)
        batch_size, _, _, h, w = pixel_values.shape
        image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])

        return {"pixel_values": pixel_values, "image_sizes": image_sizes}

    # Temporary patch before dynamic number of image tokens is supported
    _, _, h, w = vlm_config.image_input_shape
    if (w, h) != (image.width, image.height):
        logger.warning(
            "Dynamic image shape is currently not supported. "
            "Resizing input image to (%d, %d).", w, h)

        data.image = image.resize((w, h))

    return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
        ._default_input_processor(data, model_config, vlm_config)


@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor)
@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
    """
    Args to `forward()`:
        input_ids: Flattened (concatenated) input_ids corresponding to a
            batch.
        pixel_values: For PIXEL_VALUES, expects a batch with shape
            [1, num_patches, 3, 336, 336].
        image_features: For IMAGE_FEATURES, expects a batch with shape
            [1, num_patches, 1176, 1024].
    """

    def __init__(self,
                 config: LlavaNextConfig,
                 vision_language_config: VisionLanguageConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__(vision_language_config)

        # Update the type annotation from that of its superclass
        self.config = config

        if self.vision_language_config.image_input_type == (
                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
            self.vision_tower = CLIPVisionModel(config.vision_config)
        else:
            raise TypeError("Image features are not supported by LLaVA-NeXT")

        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
        _, num_channels, _, _ = self.vision_language_config.image_input_shape

        # Note that this is different from that of vLLM vision_language_config
        # since the image is resized by the HuggingFace preprocessor
        height = width = self.config.vision_config.image_size

        if list(data.shape[2:]) != [num_channels, height, width]:
            raise ValueError(
                f"The expected image tensor shape is batch dimension plus "
                f"num_patches plus {[num_channels, height, width]}. "
                f"You supplied {data.shape}. "
                f"If you are using vLLM's entrypoint, make sure your "
                f"supplied image input is consistent with "
                f"image_input_shape in engine args.")

        return data

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_features = kwargs.pop("image_features", None)

        expected_input_type = self.vision_language_config.image_input_type
        ImageInputType = VisionLanguageConfig.ImageInputType

        if expected_input_type == ImageInputType.PIXEL_VALUES:
            if image_features is not None:
                raise ValueError(
                    "Expected pixel values but got image features")
            if pixel_values is None:
                return None

            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, torch.Tensor):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_image_pixels(pixel_values),
                image_sizes=self._validate_image_sizes(image_sizes),
            )

        assert expected_input_type != ImageInputType.IMAGE_FEATURES, (
            "Failed to validate this at initialization time")

        return None

    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            orig_width, orig_height = image_size
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # image_aspect_ratio == "anyres"
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    (orig_width, orig_height),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                other_patch_embeds = other_patch_embeds \
                    .view(num_patch_width, num_patch_height, height, width, -1)

                if "unpad" in strategy:
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     image_size)
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                             .to(base_patch_embeds.device)
                        ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
            self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        b, num_patches, c, h, w = pixel_values.shape
        stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)

        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return stacked_image_features.view(b, num_patches,
                                           *stacked_image_features.shape[-2:])

    def _process_image_input(
            self, image_input: LlavaNextImageInputs) -> torch.Tensor:
        if image_input["type"] == "pixel_values":
            assert self.vision_tower is not None
            image_features = self._process_image_pixels(image_input)
        else:
            image_features = image_input["data"]

        patch_embeddings = self.multi_modal_projector(image_features)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = image_input["data"].shape[0]
            vision_config = self.config.vision_config
            default_width = default_height = vision_config.image_size
            image_sizes = torch.as_tensor([[default_width, default_height]
                                           for _ in range(batch_size)])

        merged_patch_embeddings = [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features,
                                               strategy="spatial_unpad")
            for i, patch_features in enumerate(patch_embeddings)
        ]

        return torch.stack(merged_patch_embeddings, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.
        Concretely, consider a text prompt:
        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 `32000` in the `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        The model takes two types of image inputs:
        PIXEL_VALUES and IMAGE_FEATURES.
        The following shows how each maps to the HuggingFace implementation.
        PIXEL_VALUES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
        IMAGE_FEATURES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
        before going through the multi modal projector.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: For PIXEL_VALUES, expects a batch with shape
                [1, 3, 336, 336].
            image_features: For IMAGE_FEATURES, expects a batch with shape
                [1, 576, 1024].
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vision_language_config.image_token_id)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
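As a sanity check on the spatial_unpad merge implemented above, the (672, 672) entry of image_hw_to_feature_size in tests/models/test_llava_next.py can be reproduced by hand. The arithmetic below assumes the CLIP tower takes 336x336 inputs with patch size 14, i.e. a 24x24 base grid (the 576 tokens mentioned in the forward docstring):

# Assumed CLIP geometry: 336x336 input, patch size 14.
image_size, patch_size = 336, 14
base_side = image_size // patch_size      # 24
base_patches = base_side * base_side      # 576 tokens for the base image

# A 672x672 image is tiled into a 2x2 grid of 336x336 patches ("anyres").
grid = 672 // image_size                  # 2
spatial_h = spatial_w = grid * base_side  # 48

# The image is square, so unpad_image removes nothing; one image_newline
# embedding is appended to each spatial row.
spatial_tokens = spatial_h * (spatial_w + 1)  # 48 * 49 = 2352

assert base_patches + spatial_tokens == 2928  # matches (672, 672): 2928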