[bug fix] Fix llava next feature size calculation. (#6339)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
This commit is contained in:
xwjiang2010 2024-07-11 10:21:10 -07:00 committed by GitHub
parent 52b7fcb35a
commit 1df43de9bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 9 deletions

View File

@ -1,8 +1,10 @@
from typing import List, Optional, Tuple
import pytest
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer
from vllm.model_executor.models.llava_next import (
get_llava_next_image_feature_size)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -120,3 +122,13 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize(
    "height_and_width_and_result",
    [
        (1669, 2560, 2144),
        (183, 488, 776),
    ],
)
def test_image_feature_size(height_and_width_and_result):
    """Check the LLaVA-NeXT image feature size for known input resolutions.

    Each parametrized case is an ``(input_height, input_width, expected)``
    triple; the computed feature size must equal ``expected`` exactly.
    """
    input_height, input_width, expected_size = height_and_width_and_result
    # The feature size is computed from the model config loaded here.
    hf_config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
    actual_size = get_llava_next_image_feature_size(
        hf_config,
        input_height=input_height,
        input_width=input_width,
    )
    assert actual_size == expected_size

View File

@ -74,19 +74,21 @@ def _get_llava_next_num_unpadded_features(
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
current_height = torch.tensor(current_height).to("cuda")
current_width = torch.tensor(current_width).to("cuda")
aspect_ratio: float = width / height
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
if new_height % 2 == 1:
new_height += 1
current_height = new_height
scale_factor = current_width / width
new_height = int(height * scale_factor)
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
new_width = (width * current_height) // height
if new_width % 2 == 1:
new_width += 1
current_width = new_width
scale_factor = current_height / height
new_width = int(width * scale_factor)
padding = (current_width - new_width) // 2
current_width -= padding * 2
unpadded_features = current_height * current_width
newline_features = current_height