[Bugfix] Fix InternVL2 inference with various num_patches (#8375)
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
520ca380ae
commit
e56bf27741
@ -331,6 +331,41 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", ["OpenGVLab/InternVL2-2B"])
|
||||||
|
@pytest.mark.parametrize("size_factors", [[0.5, 1.0]])
|
||||||
|
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [128])
|
||||||
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_different_num_patches(hf_runner, vllm_runner, image_assets, model,
|
||||||
|
size_factors, dtype: str, max_tokens: int,
|
||||||
|
num_logprobs: int) -> None:
|
||||||
|
images = [asset.pil_image.resize((896, 896)) for asset in image_assets]
|
||||||
|
|
||||||
|
inputs_batching = [(
|
||||||
|
[prompt for _ in size_factors],
|
||||||
|
[rescale_image_size(image, factor) for factor in size_factors],
|
||||||
|
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||||
|
|
||||||
|
inputs_multi_images = [
|
||||||
|
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
|
||||||
|
[[rescale_image_size(image, factor) for image in images]
|
||||||
|
for factor in size_factors])
|
||||||
|
]
|
||||||
|
for inputs in [inputs_batching, inputs_multi_images]:
|
||||||
|
run_test(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
inputs,
|
||||||
|
model,
|
||||||
|
dtype=dtype,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
mm_limit=2,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
|
"models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
@ -270,6 +270,7 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
|
|||||||
# Add an N dimension for number of images per prompt (currently 1).
|
# Add an N dimension for number of images per prompt (currently 1).
|
||||||
data = data.unsqueeze(0)
|
data = data.unsqueeze(0)
|
||||||
elif is_list_of(data, Image.Image):
|
elif is_list_of(data, Image.Image):
|
||||||
|
# we can't stack here because the images may have different num_patches
|
||||||
data = [
|
data = [
|
||||||
image_to_pixel_values(img,
|
image_to_pixel_values(img,
|
||||||
image_size,
|
image_size,
|
||||||
@ -277,7 +278,6 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
|
|||||||
max_num,
|
max_num,
|
||||||
use_thumbnail=use_thumbnail) for img in data
|
use_thumbnail=use_thumbnail) for img in data
|
||||||
]
|
]
|
||||||
data = torch.stack(data)
|
|
||||||
model_config = ctx.model_config
|
model_config = ctx.model_config
|
||||||
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
||||||
trust_remote_code=True)
|
trust_remote_code=True)
|
||||||
@ -449,11 +449,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
|
|||||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||||
raise ValueError("Incorrect type of pixel values. "
|
raise ValueError("Incorrect type of pixel values. "
|
||||||
f"Got type: {type(pixel_values)}")
|
f"Got type: {type(pixel_values)}")
|
||||||
|
# We need to flatten (B, N, P) to (B*N*P),
|
||||||
|
# so we call flatten_bn twice.
|
||||||
return InternVLImagePixelInputs(
|
return InternVLImagePixelInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
data=self._validate_pixel_values(
|
data=self._validate_pixel_values(
|
||||||
flatten_bn(pixel_values, concat=True).flatten(0, 1)),
|
flatten_bn(flatten_bn(pixel_values), concat=True)),
|
||||||
)
|
)
|
||||||
|
|
||||||
raise AssertionError("This line should be unreachable.")
|
raise AssertionError("This line should be unreachable.")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user