[Model] Compute Llava Next Max Tokens / Dummy Data From Gridpoints (#9650)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
parent
c866e0079d
commit
722d46edb9
@ -3,12 +3,13 @@ from typing import List, Optional, Tuple, Type, overload
|
|||||||
import pytest
|
import pytest
|
||||||
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
|
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
|
||||||
|
|
||||||
|
from vllm.inputs import InputContext
|
||||||
from vllm.multimodal.utils import rescale_image_size
|
from vllm.multimodal.utils import rescale_image_size
|
||||||
from vllm.sequence import SampleLogprobs
|
from vllm.sequence import SampleLogprobs
|
||||||
|
|
||||||
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
|
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
|
||||||
_ImageAssets)
|
_ImageAssets)
|
||||||
from ...utils import check_logprobs_close
|
from ...utils import build_model_context, check_logprobs_close
|
||||||
|
|
||||||
_LIMIT_IMAGE_PER_PROMPT = 4
|
_LIMIT_IMAGE_PER_PROMPT = 4
|
||||||
|
|
||||||
@ -22,6 +23,19 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
|||||||
models = ["llava-hf/llava-v1.6-mistral-7b-hf"]
|
models = ["llava-hf/llava-v1.6-mistral-7b-hf"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def get_max_llava_next_image_tokens():
|
||||||
|
from vllm.model_executor.models.llava_next import (
|
||||||
|
get_max_llava_next_image_tokens)
|
||||||
|
return get_max_llava_next_image_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def dummy_data_for_llava_next():
|
||||||
|
from vllm.model_executor.models.llava_next import dummy_data_for_llava_next
|
||||||
|
return dummy_data_for_llava_next
|
||||||
|
|
||||||
|
|
||||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||||
Optional[SampleLogprobs]],
|
Optional[SampleLogprobs]],
|
||||||
model: str):
|
model: str):
|
||||||
@ -281,3 +295,53 @@ def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
|
|||||||
num_logprobs=num_logprobs,
|
num_logprobs=num_logprobs,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("gridpoints,expected_max_tokens", [
|
||||||
|
([[336, 336]], 1176),
|
||||||
|
([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]], 2928),
|
||||||
|
])
|
||||||
|
def test_get_max_llava_next_image_tokens(gridpoints, expected_max_tokens,
|
||||||
|
get_max_llava_next_image_tokens):
|
||||||
|
ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")
|
||||||
|
|
||||||
|
# Update the config image_grid_pinpoints
|
||||||
|
# and calculate the resulting max tokens
|
||||||
|
ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
|
||||||
|
|
||||||
|
actual_max_tokens = get_max_llava_next_image_tokens(
|
||||||
|
InputContext(ctx.model_config))
|
||||||
|
|
||||||
|
assert expected_max_tokens == actual_max_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"gridpoints,expected_size",
|
||||||
|
[
|
||||||
|
# One point; it has to be the largest
|
||||||
|
([[336, 336]], (336, 336)),
|
||||||
|
# Default for most llava next models; the 2x2 tile is the largest
|
||||||
|
([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
|
||||||
|
(672, 672)),
|
||||||
|
# If two rectangular gridpoints are the same, the more vertical
|
||||||
|
# one has the higher feature count due to newline features
|
||||||
|
([[336, 672], [672, 336]], (672, 336))
|
||||||
|
])
|
||||||
|
def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
|
||||||
|
gridpoints, expected_size):
|
||||||
|
ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")
|
||||||
|
|
||||||
|
# Update the config image_grid_pinpoints
|
||||||
|
ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
|
||||||
|
seq_len = 5000 # bigger than the max feature size for any image
|
||||||
|
|
||||||
|
seq_data, mm_data = dummy_data_for_llava_next(
|
||||||
|
ctx,
|
||||||
|
seq_len=seq_len,
|
||||||
|
mm_counts={"image": 1},
|
||||||
|
)
|
||||||
|
|
||||||
|
# The dummy data dims should match the gridpoint with the biggest feat size
|
||||||
|
assert mm_data["image"].height == expected_size[0]
|
||||||
|
assert mm_data["image"].width == expected_size[1]
|
||||||
|
assert len(seq_data.get_token_ids()) >= seq_len
|
||||||
|
|||||||
@ -33,9 +33,6 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
|||||||
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
|
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
|
||||||
init_vllm_registered_model)
|
init_vllm_registered_model)
|
||||||
|
|
||||||
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
|
|
||||||
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
|
|
||||||
|
|
||||||
|
|
||||||
class LlavaNextImagePixelInputs(TypedDict):
|
class LlavaNextImagePixelInputs(TypedDict):
|
||||||
type: Literal["pixel_values"]
|
type: Literal["pixel_values"]
|
||||||
@ -149,11 +146,28 @@ def get_llava_next_image_feature_size(
|
|||||||
|
|
||||||
|
|
||||||
def get_max_llava_next_image_tokens(ctx: InputContext):
|
def get_max_llava_next_image_tokens(ctx: InputContext):
|
||||||
return get_llava_next_image_feature_size(
|
"""Compute the max feature size for all possible image grid pinpoints."""
|
||||||
ctx.get_hf_config(LlavaNextConfig),
|
return _get_pinpoint_with_largest_features(ctx)[0]
|
||||||
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
|
||||||
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
|
||||||
|
def _get_pinpoint_with_largest_features(
|
||||||
|
ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
|
||||||
|
"""Get the grid pinpoint with the largest features & its feature size."""
|
||||||
|
hf_config = ctx.get_hf_config(LlavaNextConfig)
|
||||||
|
largest_feature_size = 0
|
||||||
|
largest_feature_pinpoint = None
|
||||||
|
for (height, width) in hf_config.image_grid_pinpoints:
|
||||||
|
feat_size = get_llava_next_image_feature_size(
|
||||||
|
hf_config,
|
||||||
|
input_height=height,
|
||||||
|
input_width=width,
|
||||||
)
|
)
|
||||||
|
if feat_size > largest_feature_size:
|
||||||
|
largest_feature_size = feat_size
|
||||||
|
largest_feature_pinpoint = (height, width)
|
||||||
|
if not largest_feature_size or largest_feature_pinpoint is None:
|
||||||
|
raise ValueError("Cannot have a largest feature size of 0!")
|
||||||
|
return largest_feature_size, largest_feature_pinpoint
|
||||||
|
|
||||||
|
|
||||||
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
|
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
|
||||||
@ -162,7 +176,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
|
|||||||
vision_config = hf_config.vision_config
|
vision_config = hf_config.vision_config
|
||||||
num_images = mm_counts["image"]
|
num_images = mm_counts["image"]
|
||||||
|
|
||||||
image_feature_size = get_max_llava_next_image_tokens(ctx)
|
image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx)
|
||||||
|
max_feat_height, max_feat_width = pinpoint
|
||||||
|
|
||||||
if isinstance(vision_config, CLIPVisionConfig):
|
if isinstance(vision_config, CLIPVisionConfig):
|
||||||
seq_data = dummy_seq_data_for_clip(
|
seq_data = dummy_seq_data_for_clip(
|
||||||
@ -176,8 +191,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
|
|||||||
mm_data = dummy_image_for_clip(
|
mm_data = dummy_image_for_clip(
|
||||||
vision_config,
|
vision_config,
|
||||||
num_images,
|
num_images,
|
||||||
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
image_width_override=max_feat_width,
|
||||||
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
image_height_override=max_feat_height,
|
||||||
)
|
)
|
||||||
|
|
||||||
return seq_data, mm_data
|
return seq_data, mm_data
|
||||||
@ -193,8 +208,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
|
|||||||
mm_data = dummy_image_for_siglip(
|
mm_data = dummy_image_for_siglip(
|
||||||
vision_config,
|
vision_config,
|
||||||
num_images,
|
num_images,
|
||||||
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
image_width_override=max_feat_width,
|
||||||
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
image_height_override=max_feat_height,
|
||||||
)
|
)
|
||||||
|
|
||||||
return seq_data, mm_data
|
return seq_data, mm_data
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user