[Model] Compute Llava Next Max Tokens / Dummy Data From Gridpoints (#9650)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
Alex Brooks 2024-10-24 11:42:24 -06:00 committed by GitHub
parent c866e0079d
commit 722d46edb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 93 additions and 14 deletions

View File

@ -3,12 +3,13 @@ from typing import List, Optional, Tuple, Type, overload
import pytest import pytest
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from vllm.inputs import InputContext
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets) _ImageAssets)
from ...utils import check_logprobs_close from ...utils import build_model_context, check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT = 4 _LIMIT_IMAGE_PER_PROMPT = 4
@ -22,6 +23,19 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
models = ["llava-hf/llava-v1.6-mistral-7b-hf"] models = ["llava-hf/llava-v1.6-mistral-7b-hf"]
@pytest.fixture()
def get_max_llava_next_image_tokens():
from vllm.model_executor.models.llava_next import (
get_max_llava_next_image_tokens)
return get_max_llava_next_image_tokens
@pytest.fixture()
def dummy_data_for_llava_next():
from vllm.model_executor.models.llava_next import dummy_data_for_llava_next
return dummy_data_for_llava_next
def vllm_to_hf_output(vllm_output: Tuple[List[int], str, def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]], Optional[SampleLogprobs]],
model: str): model: str):
@ -281,3 +295,53 @@ def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
tensor_parallel_size=1, tensor_parallel_size=1,
) )
@pytest.mark.parametrize("gridpoints,expected_max_tokens", [
([[336, 336]], 1176),
([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]], 2928),
])
def test_get_max_llava_next_image_tokens(gridpoints, expected_max_tokens,
get_max_llava_next_image_tokens):
ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")
# Update the config image_grid_pinpoints
# and calculate the resulting max tokens
ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
actual_max_tokens = get_max_llava_next_image_tokens(
InputContext(ctx.model_config))
assert expected_max_tokens == actual_max_tokens
@pytest.mark.parametrize(
"gridpoints,expected_size",
[
# One point; it has to be the largest
([[336, 336]], (336, 336)),
# Default for most llava next models; the 2x2 tile is the largest
([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
(672, 672)),
# If two rectangular gridpoints are the same, the more vertical
# one has the higher feature count due to newline features
([[336, 672], [672, 336]], (672, 336))
])
def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
gridpoints, expected_size):
ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")
# Update the config image_grid_pinpoints
ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
seq_len = 5000 # bigger than the max feature size for any image
seq_data, mm_data = dummy_data_for_llava_next(
ctx,
seq_len=seq_len,
mm_counts={"image": 1},
)
# The dummy data dims should match the gridpoint with the biggest feat size
assert mm_data["image"].height == expected_size[0]
assert mm_data["image"].width == expected_size[1]
assert len(seq_data.get_token_ids()) >= seq_len

View File

@ -33,9 +33,6 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn, from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model) init_vllm_registered_model)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
class LlavaNextImagePixelInputs(TypedDict): class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
@ -149,11 +146,28 @@ def get_llava_next_image_feature_size(
def get_max_llava_next_image_tokens(ctx: InputContext): def get_max_llava_next_image_tokens(ctx: InputContext):
return get_llava_next_image_feature_size( """Compute the max feature size for all possible image grid pinpoints."""
ctx.get_hf_config(LlavaNextConfig), return _get_pinpoint_with_largest_features(ctx)[0]
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
def _get_pinpoint_with_largest_features(
ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
"""Get the grid pinpoint with the largest features & its feature size."""
hf_config = ctx.get_hf_config(LlavaNextConfig)
largest_feature_size = 0
largest_feature_pinpoint = None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = get_llava_next_image_feature_size(
hf_config,
input_height=height,
input_width=width,
) )
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = (height, width)
if not largest_feature_size or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_size, largest_feature_pinpoint
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
@ -162,7 +176,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
num_images = mm_counts["image"] num_images = mm_counts["image"]
image_feature_size = get_max_llava_next_image_tokens(ctx) image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx)
max_feat_height, max_feat_width = pinpoint
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip( seq_data = dummy_seq_data_for_clip(
@ -176,8 +191,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_data = dummy_image_for_clip( mm_data = dummy_image_for_clip(
vision_config, vision_config,
num_images, num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_width_override=max_feat_width,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, image_height_override=max_feat_height,
) )
return seq_data, mm_data return seq_data, mm_data
@ -193,8 +208,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_data = dummy_image_for_siglip( mm_data = dummy_image_for_siglip(
vision_config, vision_config,
num_images, num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_width_override=max_feat_width,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, image_height_override=max_feat_height,
) )
return seq_data, mm_data return seq_data, mm_data