[Bugfix] Fix feature size calculation for LLaVA-NeXT (#6982)
parent: 2f4e108f75
commit: daed30c4a9
@@ -1,7 +1,7 @@
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type, overload

 import pytest
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoTokenizer

 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
@@ -50,6 +50,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
     return hf_output_ids, hf_output_str, out_logprobs


+@overload
 def run_test(
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
@@ -62,13 +63,55 @@ def run_test(
     num_logprobs: int,
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+@overload
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    sizes: List[Tuple[int, int]],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: Optional[List[float]] = None,
+    sizes: Optional[List[Tuple[int, int]]] = None,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
 ):
     images = [asset.pil_image for asset in image_assets]

-    inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [rescale_image_size(image, factor) for factor in size_factors],
-    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    if size_factors is not None:
+        inputs_per_image = [(
+            [prompt for _ in size_factors],
+            [rescale_image_size(image, factor) for factor in size_factors],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    elif sizes is not None:
+        inputs_per_image = [(
+            [prompt for _ in sizes],
+            [image.resize(size) for size in sizes],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    else:
+        raise ValueError("You must provide either `size_factors` or `sizes`")

     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
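The two @overload stubs above exist only for type checkers: they advertise that a caller must pass either size_factors or sizes, while the single runtime implementation accepts both as Optional and dispatches on whichever is not None. A minimal, self-contained sketch of the same pattern; the function name make_inputs is illustrative and not part of the diff:

from typing import List, Optional, Tuple, overload


@overload
def make_inputs(*, size_factors: List[float]) -> List[float]: ...


@overload
def make_inputs(*, sizes: List[Tuple[int, int]]) -> List[Tuple[int, int]]: ...


def make_inputs(
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
):
    # Runtime dispatch mirrors run_test: exactly one keyword must be provided.
    if size_factors is not None:
        return size_factors
    if sizes is not None:
        return sizes
    raise ValueError("You must provide either `size_factors` or `sizes`")


make_inputs(size_factors=[0.25, 0.5, 1.0])      # accepted by type checkers
make_inputs(sizes=[(1669, 2560), (183, 488)])   # accepted by type checkers
# make_inputs() is flagged by a type checker and raises ValueError at runtime.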
@@ -150,15 +193,24 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     )


-@pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144),
-                                                         (183, 488, 776)])
-def test_image_feature_size(height_and_width_and_result):
-    # Avoid initializing CUDA too early in distributed tests
-    from vllm.model_executor.models.llava_next import (
-        get_llava_next_image_feature_size)
-
-    height, width, result = height_and_width_and_result
-    config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
-    assert get_llava_next_image_feature_size(config,
-                                             input_height=height,
-                                             input_width=width) == result
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "sizes",
+    [[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
+                            dtype, max_tokens, num_logprobs) -> None:
+    run_test(
+        hf_runner,
+        vllm_runner,
+        image_assets,
+        model,
+        sizes=sizes,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
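A note on the parametrization above: the outer list passed to @pytest.mark.parametrize has a single element, so pytest generates one test case per model/dtype combination, and that case receives the whole list of four tuples as sizes; each tuple is then handed straight to PIL's Image.resize, which reads it as (width, height). A small sketch of that behaviour; the test name below is illustrative:

import pytest


@pytest.mark.parametrize(
    "sizes",
    [[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
)
def test_parametrize_shape(sizes):
    # One generated test case; `sizes` is the full list, not a single tuple.
    assert len(sizes) == 4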
@@ -169,7 +169,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
         raise TypeError(f"Invalid image type: {type(image_data)}")

     # process prompts
-    prompt = llm_inputs["prompt"]
+    prompt = llm_inputs.get("prompt")
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     tokenizer = cached_get_tokenizer(model_config.model)
     # dim0 is batch_size, dim1 is subseq_size which will always be 1
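Switching from llm_inputs["prompt"] to llm_inputs.get("prompt") matters when a request carries only token IDs and no prompt string. A minimal sketch, using a plain dict to stand in for LLMInputs (the dict itself is illustrative, not the vLLM type):

llm_inputs = {"prompt_token_ids": [1, 2, 3]}   # no "prompt" key supplied

prompt = llm_inputs.get("prompt")   # None; downstream code can decode the IDs instead
try:
    prompt = llm_inputs["prompt"]   # the old lookup: raises KeyError
except KeyError:
    pass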
@@ -20,7 +20,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.intern_vit import InternVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.image import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -43,7 +43,7 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500

 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: BatchedTensors
+    data: Union[torch.Tensor, List[torch.Tensor]]
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

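Spelling the field out as Union[torch.Tensor, List[torch.Tensor]] documents the two forms the pixel data can take: one stacked tensor when every image in the batch yields the same number of patches, or a list of per-image tensors when the counts differ. A hedged sketch of the distinction; the shapes below are illustrative only:

from typing import List, Union

import torch

# Every image produced the same patch count: a single stacked tensor of shape
# (batch_size, 1 + num_patches, num_channels, height, width).
stacked: Union[torch.Tensor, List[torch.Tensor]] = torch.zeros(2, 5, 3, 336, 336)

# Patch counts differ across the batch: one tensor per image instead.
ragged: Union[torch.Tensor, List[torch.Tensor]] = [
    torch.zeros(5, 3, 336, 336),
    torch.zeros(3, 3, 336, 336),
]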
@@ -193,7 +193,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)

-    prompt = llm_inputs["prompt"]
+    prompt = llm_inputs.get("prompt")
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     if prompt is None:
         prompt = tokenizer.decode(prompt_token_ids)
@@ -21,7 +21,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SamplerOutput

 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
@@ -43,7 +43,7 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448

 class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: BatchedTensors
+    data: Union[torch.Tensor, List[torch.Tensor]]
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

@@ -62,31 +62,26 @@ class LlavaNextImagePixelInputs(TypedDict):
 LlavaNextImageInputs = LlavaNextImagePixelInputs


-# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
-# NOTE: new_height and new_width are further incremented to properly invert the
-# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
+# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
 def _get_llava_next_num_unpadded_features(
-    height: int,
-    width: int,
+    original_height: int,
+    original_width: int,
     npatches: int,
     num_patch_height: int,
     num_patch_width: int,
 ) -> Tuple[int, int]:
     current_height = npatches * num_patch_height
     current_width = npatches * num_patch_width
-    current_height = torch.tensor(current_height).to("cuda")
-    current_width = torch.tensor(current_width).to("cuda")

-    aspect_ratio: float = width / height
-    current_aspect_ratio: float = current_width / current_height
+    aspect_ratio = original_width / original_height
+    current_aspect_ratio = current_width / current_height

     if aspect_ratio > current_aspect_ratio:
-        scale_factor = current_width / width
-        new_height = int(height * scale_factor)
+        new_height = (original_height * current_width) // original_width
         padding = (current_height - new_height) // 2
         current_height -= padding * 2
     else:
-        scale_factor = current_height / height
-        new_width = int(width * scale_factor)
+        new_width = (original_width * current_height) // original_height
         padding = (current_width - new_width) // 2
         current_width -= padding * 2

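The rewritten function stays in plain Python integers and replaces the float scale factor with floor division, matching the integer arithmetic of the reference implementation it links to. A standalone sketch with a hand-checkable example; the helper name is illustrative, and the last two assignments reproduce lines that sit between this hunk and the next, so they are an assumption rather than part of the diff:

from typing import Tuple


def sketch_num_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    # Mirrors the fixed function above: pure-int math, no CUDA tensors,
    # floor division instead of a float scale factor.
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height -= padding * 2
    else:
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width -= padding * 2

    # The return contract matches the context line in the next hunk; the two
    # assignments below reproduce code between the hunks and are an assumption.
    unpadded_features = current_height * current_width
    newline_features = current_height
    return unpadded_features, newline_features


# Hand-checkable example with illustrative numbers: a 200x300 image on a
# 1x2 grid of 24-patch tiles. aspect_ratio 1.5 < current_aspect_ratio 2.0,
# so the width is unpadded: new_width = 300 * 24 // 200 = 36,
# padding = (48 - 36) // 2 = 6, current_width = 48 - 12 = 36.
assert sketch_num_unpadded_features(200, 300, 24, 1, 2) == (24 * 36, 24)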
@@ -95,7 +90,7 @@ def _get_llava_next_num_unpadded_features(
     return (unpadded_features, newline_features)


-# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
+# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
 def get_llava_next_image_feature_size(
     hf_config: LlavaNextConfig,
     *,
@@ -111,9 +106,7 @@ def get_llava_next_image_feature_size(
         )
         base_feature_size = num_patches * num_patches

-        # Note: We follow the "wrong" width/height order
-        # [ref: PR huggingface/transformers#31588]
-        num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
             image_size=(input_height, input_width),
             grid_pinpoints=hf_config.image_grid_pinpoints,
             patch_size=vision_config.image_size,
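The fix above destructures the grid as (num_patch_height, num_patch_width), rows first, instead of the previous swapped order. The toy snippet below illustrates why the order matters for non-square images; it assumes get_anyres_image_grid_shape returns the selected anyres resolution divided by the base tile size, with height before width, and it elides the pinpoint selection itself:

def toy_grid_shape(best_resolution, patch_size):
    # Stand-in for get_anyres_image_grid_shape (assumption, simplified).
    best_height, best_width = best_resolution
    return best_height // patch_size, best_width // patch_size


# A portrait image whose best anyres resolution is 672x336 with a 336-pixel tile:
num_patch_height, num_patch_width = toy_grid_shape((672, 336), 336)
assert (num_patch_height, num_patch_width) == (2, 1)

# Unpacking the same return value as (width, height), as the old code did,
# would treat the image as one tile tall and two tiles wide, which skews the
# computed feature size whenever the image is not square.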
@@ -349,11 +342,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         if patch_embeddings.shape[0] > 1:
             other_patch_embeds = patch_embeddings[1:]

+            # Move to CPU to avoid floating-point errors
+            orig_height, orig_width = image_size.tolist()
+
             # image_aspect_ratio == "anyres"
-            # Note: We follow the "wrong" width/height order
-            # [ref: PR huggingface/transformers#31588]
-            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                image_size,
+            num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                (orig_height, orig_width),
                 self.config.image_grid_pinpoints,
                 self.config.vision_config.image_size,
             )
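The new orig_height, orig_width = image_size.tolist() pulls the size back to plain Python ints before any arithmetic, which is what the "Move to CPU" comment refers to: the code removed in the earlier hunk did similar math on float CUDA tensors, where rounding can differ from exact integer floor division. A small hedged sketch of the kind of discrepancy this avoids; the numbers are illustrative, loosely based on the test sizes, and CPU tensors are used so the snippet runs anywhere:

import torch

# With plain ints the division is exact: 183 * 48 == 8784 == 18 * 488.
exact = (183 * 48) // 488
assert exact == 18

# A float path computes a float32 scale factor first; whether
# int(183 * (48 / 488)) still lands on 18 then depends on float32 rounding,
# which is precisely the class of off-by-one error the integer math avoids.
scale = torch.tensor(48.0) / torch.tensor(488.0)
maybe_off_by_one = int(torch.tensor(183.0) * scale)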
@@ -365,7 +359,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
                 .permute(4, 0, 2, 1, 3).contiguous() \
                 .flatten(1, 2).flatten(2, 3)
             other_patch_embeds = unpad_image(other_patch_embeds,
-                                             image_size)
+                                             (orig_height, orig_width))
             other_patch_embeds = torch.cat((
                 other_patch_embeds,
                 self.image_newline[:, None, None] \
@@ -398,7 +392,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
     def _process_image_pixels(
         self,
         inputs: LlavaNextImagePixelInputs,
-    ) -> BatchedTensors:
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         assert self.vision_tower is not None

         pixel_values = inputs["data"]
@@ -425,7 +419,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         ]

     def _process_image_input(
-            self, image_input: LlavaNextImageInputs) -> BatchedTensors:
+        self,
+        image_input: LlavaNextImageInputs,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         patch_embeddings = self._process_image_pixels(image_input)

         image_sizes = image_input.get("image_sizes")
@@ -36,7 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SamplerOutput

@@ -261,7 +261,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):

 class Phi3VImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: BatchedTensors
+    data: Union[torch.Tensor, List[torch.Tensor]]
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
