[Bugfix] Fix feature size calculation for LLaVA-NeXT (#6982)
This commit is contained in:
parent
2f4e108f75
commit
daed30c4a9
@ -1,7 +1,7 @@
|
||||
from typing import List, Optional, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Type, overload
|
||||
|
||||
import pytest
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
@ -50,6 +50,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
return hf_output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
@overload
|
||||
def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
@ -62,13 +63,55 @@ def run_test(
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
model: str,
|
||||
*,
|
||||
sizes: List[Tuple[int, int]],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
model: str,
|
||||
*,
|
||||
size_factors: Optional[List[float]] = None,
|
||||
sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
if size_factors is not None:
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
elif sizes is not None:
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in sizes],
|
||||
[image.resize(size) for size in sizes],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
else:
|
||||
raise ValueError("You must provide either `size_factors` or `sizes`")
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
@ -150,15 +193,24 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144),
|
||||
(183, 488, 776)])
|
||||
def test_image_feature_size(height_and_width_and_result):
|
||||
# Avoid initializing CUDA too early in distributed tests
|
||||
from vllm.model_executor.models.llava_next import (
|
||||
get_llava_next_image_feature_size)
|
||||
|
||||
height, width, result = height_and_width_and_result
|
||||
config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
|
||||
assert get_llava_next_image_feature_size(config,
|
||||
input_height=height,
|
||||
input_width=width) == result
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"sizes",
|
||||
[[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
|
||||
dtype, max_tokens, num_logprobs) -> None:
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model,
|
||||
sizes=sizes,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
@ -169,7 +169,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
# process prompts
|
||||
prompt = llm_inputs["prompt"]
|
||||
prompt = llm_inputs.get("prompt")
|
||||
prompt_token_ids = llm_inputs["prompt_token_ids"]
|
||||
tokenizer = cached_get_tokenizer(model_config.model)
|
||||
# dim0 is batch_size, dim1 is subseq_size which will always be 1
|
||||
|
||||
@ -20,7 +20,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.models.intern_vit import InternVisionModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalInputs
|
||||
from vllm.multimodal.image import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
@ -43,7 +43,7 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500
|
||||
|
||||
class InternVLImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: BatchedTensors
|
||||
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
|
||||
|
||||
@ -193,7 +193,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
||||
trust_remote_code=True)
|
||||
|
||||
prompt = llm_inputs["prompt"]
|
||||
prompt = llm_inputs.get("prompt")
|
||||
prompt_token_ids = llm_inputs["prompt_token_ids"]
|
||||
if prompt is None:
|
||||
prompt = tokenizer.decode(prompt_token_ids)
|
||||
|
||||
@ -21,7 +21,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
@ -43,7 +43,7 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
|
||||
|
||||
class LlavaNextImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: BatchedTensors
|
||||
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
|
||||
|
||||
@ -62,31 +62,26 @@ class LlavaNextImagePixelInputs(TypedDict):
|
||||
LlavaNextImageInputs = LlavaNextImagePixelInputs
|
||||
|
||||
|
||||
# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
|
||||
# NOTE: new_height and new_width are further incremented to properly invert the
|
||||
# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
|
||||
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
|
||||
def _get_llava_next_num_unpadded_features(
|
||||
height: int,
|
||||
width: int,
|
||||
original_height: int,
|
||||
original_width: int,
|
||||
npatches: int,
|
||||
num_patch_height: int,
|
||||
num_patch_width: int,
|
||||
) -> Tuple[int, int]:
|
||||
current_height = npatches * num_patch_height
|
||||
current_width = npatches * num_patch_width
|
||||
current_height = torch.tensor(current_height).to("cuda")
|
||||
current_width = torch.tensor(current_width).to("cuda")
|
||||
|
||||
aspect_ratio: float = width / height
|
||||
current_aspect_ratio: float = current_width / current_height
|
||||
aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
|
||||
if aspect_ratio > current_aspect_ratio:
|
||||
scale_factor = current_width / width
|
||||
new_height = int(height * scale_factor)
|
||||
new_height = (original_height * current_width) // original_width
|
||||
padding = (current_height - new_height) // 2
|
||||
current_height -= padding * 2
|
||||
else:
|
||||
scale_factor = current_height / height
|
||||
new_width = int(width * scale_factor)
|
||||
new_width = (original_width * current_height) // original_height
|
||||
padding = (current_width - new_width) // 2
|
||||
current_width -= padding * 2
|
||||
|
||||
@ -95,7 +90,7 @@ def _get_llava_next_num_unpadded_features(
|
||||
return (unpadded_features, newline_features)
|
||||
|
||||
|
||||
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
|
||||
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
|
||||
def get_llava_next_image_feature_size(
|
||||
hf_config: LlavaNextConfig,
|
||||
*,
|
||||
@ -111,9 +106,7 @@ def get_llava_next_image_feature_size(
|
||||
)
|
||||
base_feature_size = num_patches * num_patches
|
||||
|
||||
# Note: We follow the "wrong" width/height order
|
||||
# [ref: PR huggingface/transformers#31588]
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_size=(input_height, input_width),
|
||||
grid_pinpoints=hf_config.image_grid_pinpoints,
|
||||
patch_size=vision_config.image_size,
|
||||
@ -349,11 +342,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
if patch_embeddings.shape[0] > 1:
|
||||
other_patch_embeds = patch_embeddings[1:]
|
||||
|
||||
# Move to CPU to avoid floating-point errors
|
||||
orig_height, orig_width = image_size.tolist()
|
||||
|
||||
# image_aspect_ratio == "anyres"
|
||||
# Note: We follow the "wrong" width/height order
|
||||
# [ref: PR huggingface/transformers#31588]
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
image_size,
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
(orig_height, orig_width),
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
@ -365,7 +359,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
.permute(4, 0, 2, 1, 3).contiguous() \
|
||||
.flatten(1, 2).flatten(2, 3)
|
||||
other_patch_embeds = unpad_image(other_patch_embeds,
|
||||
image_size)
|
||||
(orig_height, orig_width))
|
||||
other_patch_embeds = torch.cat((
|
||||
other_patch_embeds,
|
||||
self.image_newline[:, None, None] \
|
||||
@ -398,7 +392,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
def _process_image_pixels(
|
||||
self,
|
||||
inputs: LlavaNextImagePixelInputs,
|
||||
) -> BatchedTensors:
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = inputs["data"]
|
||||
@ -425,7 +419,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
]
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: LlavaNextImageInputs) -> BatchedTensors:
|
||||
self,
|
||||
image_input: LlavaNextImageInputs,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
patch_embeddings = self._process_image_pixels(image_input)
|
||||
|
||||
image_sizes = image_input.get("image_sizes")
|
||||
|
||||
@ -36,7 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
@ -261,7 +261,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
|
||||
class Phi3VImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: BatchedTensors
|
||||
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user