[Bugfix] Fix feature size calculation for LLaVA-NeXT (#6982)

Cyrus Leung 2024-07-31 23:46:17 +08:00 committed by GitHub
parent 2f4e108f75
commit daed30c4a9
5 changed files with 98 additions and 50 deletions
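
Aside (not part of the diff): the corrected unpadded-feature arithmetic introduced below can be exercised in isolation. The sketch mirrors the body of the new `_get_llava_next_num_unpadded_features` from this commit under a renamed helper; the example inputs (24 patches per side and a 1x2 anyres grid for a 183x488 image) are illustrative assumptions, not values read from a real model config.

from typing import Tuple


def num_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    # Same integer-only math as the fixed helper in this commit
    # (no CUDA tensors, no float scale factors).
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height -= padding * 2
    else:
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width -= padding * 2

    # (unpadded_features, newline_features)
    return (current_height * current_width, current_height)


# Illustrative inputs only: 183x488 (height x width) image, 24 patches per
# side, assumed 1x2 anyres grid.
print(num_unpadded_features(183, 488, 24, 1, 2))  # -> (864, 18)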

View File

@@ -1,7 +1,7 @@
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type, overload
 
 import pytest
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoTokenizer
 
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
@@ -50,6 +50,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
     return hf_output_ids, hf_output_str, out_logprobs
 
 
+@overload
 def run_test(
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
@@ -62,13 +63,55 @@ def run_test(
     num_logprobs: int,
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+@overload
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    sizes: List[Tuple[int, int]],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: Optional[List[float]] = None,
+    sizes: Optional[List[Tuple[int, int]]] = None,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
 ):
     images = [asset.pil_image for asset in image_assets]
 
-    inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [rescale_image_size(image, factor) for factor in size_factors],
-    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    if size_factors is not None:
+        inputs_per_image = [(
+            [prompt for _ in size_factors],
+            [rescale_image_size(image, factor) for factor in size_factors],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    elif sizes is not None:
+        inputs_per_image = [(
+            [prompt for _ in sizes],
+            [image.resize(size) for size in sizes],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    else:
+        raise ValueError("You must provide either `size_factors` or `sizes`")
 
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
@@ -150,15 +193,24 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     )
 
 
-@pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144),
-                                                         (183, 488, 776)])
-def test_image_feature_size(height_and_width_and_result):
-    # Avoid initializing CUDA too early in distributed tests
-    from vllm.model_executor.models.llava_next import (
-        get_llava_next_image_feature_size)
-
-    height, width, result = height_and_width_and_result
-    config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
-    assert get_llava_next_image_feature_size(config,
-                                             input_height=height,
-                                             input_width=width) == result
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "sizes",
+    [[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
+                            dtype, max_tokens, num_logprobs) -> None:
+    run_test(
+        hf_runner,
+        vllm_runner,
+        image_assets,
+        model,
+        sizes=sizes,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
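
Aside (not part of the commit): the `@overload` stubs above exist only for type checkers; at runtime the final `run_test` body dispatches on whichever keyword-only argument was supplied. A minimal, self-contained sketch of the same pattern, using a hypothetical `make_sizes` helper:

from typing import List, Optional, Tuple, overload


@overload
def make_sizes(base: Tuple[int, int], *,
               size_factors: List[float]) -> List[Tuple[int, int]]:
    ...


@overload
def make_sizes(base: Tuple[int, int], *,
               sizes: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    ...


def make_sizes(
    base: Tuple[int, int],
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
) -> List[Tuple[int, int]]:
    # Exactly one keyword-only argument must be given, mirroring the
    # guard added to run_test above.
    if size_factors is not None:
        return [(int(base[0] * f), int(base[1] * f)) for f in size_factors]
    elif sizes is not None:
        return list(sizes)
    else:
        raise ValueError("You must provide either `size_factors` or `sizes`")


print(make_sizes((488, 183), size_factors=[0.5, 1.0]))  # [(244, 91), (488, 183)]
print(make_sizes((488, 183), sizes=[(1669, 2560)]))     # [(1669, 2560)]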

View File

@@ -169,7 +169,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
     # process prompts
-    prompt = llm_inputs["prompt"]
+    prompt = llm_inputs.get("prompt")
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     tokenizer = cached_get_tokenizer(model_config.model)
     # dim0 is batch_size, dim1 is subseq_size which will always be 1

View File

@@ -20,7 +20,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.intern_vit import InternVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.image import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -43,7 +43,7 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500
 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: BatchedTensors
+    data: Union[torch.Tensor, List[torch.Tensor]]
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
@@ -193,7 +193,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
 
-    prompt = llm_inputs["prompt"]
+    prompt = llm_inputs.get("prompt")
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     if prompt is None:
         prompt = tokenizer.decode(prompt_token_ids)

View File

@@ -21,7 +21,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
@@ -43,7 +43,7 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
 class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: BatchedTensors
+    data: Union[torch.Tensor, List[torch.Tensor]]
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
@@ -62,31 +62,26 @@ class LlavaNextImagePixelInputs(TypedDict):
 LlavaNextImageInputs = LlavaNextImagePixelInputs
 
 
-# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
-# NOTE: new_height and new_width are further incremented to properly invert the
-# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
+# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
 def _get_llava_next_num_unpadded_features(
-    height: int,
-    width: int,
+    original_height: int,
+    original_width: int,
     npatches: int,
     num_patch_height: int,
     num_patch_width: int,
 ) -> Tuple[int, int]:
     current_height = npatches * num_patch_height
     current_width = npatches * num_patch_width
-    current_height = torch.tensor(current_height).to("cuda")
-    current_width = torch.tensor(current_width).to("cuda")
 
-    aspect_ratio: float = width / height
-    current_aspect_ratio: float = current_width / current_height
+    aspect_ratio = original_width / original_height
+    current_aspect_ratio = current_width / current_height
 
     if aspect_ratio > current_aspect_ratio:
-        scale_factor = current_width / width
-        new_height = int(height * scale_factor)
+        new_height = (original_height * current_width) // original_width
         padding = (current_height - new_height) // 2
         current_height -= padding * 2
     else:
-        scale_factor = current_height / height
-        new_width = int(width * scale_factor)
+        new_width = (original_width * current_height) // original_height
         padding = (current_width - new_width) // 2
         current_width -= padding * 2
@@ -95,7 +90,7 @@ def _get_llava_next_num_unpadded_features(
     return (unpadded_features, newline_features)
 
 
-# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
+# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
 def get_llava_next_image_feature_size(
     hf_config: LlavaNextConfig,
     *,
@@ -111,9 +106,7 @@ def get_llava_next_image_feature_size(
         )
         base_feature_size = num_patches * num_patches
 
-        # Note: We follow the "wrong" width/height order
-        # [ref: PR huggingface/transformers#31588]
-        num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
             image_size=(input_height, input_width),
             grid_pinpoints=hf_config.image_grid_pinpoints,
             patch_size=vision_config.image_size,
@@ -349,11 +342,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         if patch_embeddings.shape[0] > 1:
             other_patch_embeds = patch_embeddings[1:]
 
+            # Move to CPU to avoid floating-point errors
+            orig_height, orig_width = image_size.tolist()
+
             # image_aspect_ratio == "anyres"
-            # Note: We follow the "wrong" width/height order
-            # [ref: PR huggingface/transformers#31588]
-            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                image_size,
+            num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                (orig_height, orig_width),
                 self.config.image_grid_pinpoints,
                 self.config.vision_config.image_size,
             )
@@ -365,7 +359,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
                 .permute(4, 0, 2, 1, 3).contiguous() \
                 .flatten(1, 2).flatten(2, 3)
             other_patch_embeds = unpad_image(other_patch_embeds,
-                                             image_size)
+                                             (orig_height, orig_width))
             other_patch_embeds = torch.cat((
                 other_patch_embeds,
                 self.image_newline[:, None, None] \
@@ -398,7 +392,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
     def _process_image_pixels(
         self,
         inputs: LlavaNextImagePixelInputs,
-    ) -> BatchedTensors:
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         assert self.vision_tower is not None
 
         pixel_values = inputs["data"]
@@ -425,7 +419,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         ]
 
     def _process_image_input(
-            self, image_input: LlavaNextImageInputs) -> BatchedTensors:
+        self,
+        image_input: LlavaNextImageInputs,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         patch_embeddings = self._process_image_pixels(image_input)
 
         image_sizes = image_input.get("image_sizes")
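
Usage aside (not part of the diff): with the argument-order fix above, the feature-size helper can be queried directly against a Hugging Face config, following the call pattern of the removed unit test. No expected value is asserted here, since the old hard-coded constants are exactly what this commit revises.

from transformers import AutoConfig

from vllm.model_executor.models.llava_next import (
    get_llava_next_image_feature_size)

# Config of the model exercised by the removed test.
config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

# Feature size for a 183 (height) x 488 (width) input image.
print(get_llava_next_image_feature_size(config,
                                        input_height=183,
                                        input_width=488))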

View File

@@ -36,7 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -261,7 +261,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
 class Phi3VImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: BatchedTensors
+    data: Union[torch.Tensor, List[torch.Tensor]]
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`