[Bugfix] Fix InternVL2 inference with various num_patches (#8375)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Isotr0py authored on 2024-09-13 01:10:35 +08:00; committed by GitHub
parent 520ca380ae
commit e56bf27741
2 changed files with 39 additions and 3 deletions


@@ -331,6 +331,41 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
 )
+@pytest.mark.parametrize("model", ["OpenGVLab/InternVL2-2B"])
+@pytest.mark.parametrize("size_factors", [[0.5, 1.0]])
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+@torch.inference_mode()
+def test_different_num_patches(hf_runner, vllm_runner, image_assets, model,
+                               size_factors, dtype: str, max_tokens: int,
+                               num_logprobs: int) -> None:
+    images = [asset.pil_image.resize((896, 896)) for asset in image_assets]
+
+    inputs_batching = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
+    inputs_multi_images = [
+        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+         [[rescale_image_size(image, factor) for image in images]
+          for factor in size_factors])
+    ]
+    for inputs in [inputs_batching, inputs_multi_images]:
+        run_test(
+            hf_runner,
+            vllm_runner,
+            inputs,
+            model,
+            dtype=dtype,
+            max_tokens=max_tokens,
+            num_logprobs=num_logprobs,
+            mm_limit=2,
+            tensor_parallel_size=1,
+        )
+
+
 @pytest.mark.parametrize(
     "models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
 @pytest.mark.parametrize(
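
For context (not part of the diff): InternVL2 tiles each input image into 448x448 patches, so rescaling by the two size factors above produces pixel-value tensors with different leading dimensions in the same batch. Below is a minimal sketch of that effect, ignoring InternVL's aspect-ratio search and optional thumbnail tile; `fake_image_to_pixel_values` is a hypothetical stand-in, not the repository's helper:

```python
import torch

def fake_image_to_pixel_values(width: int, height: int,
                               patch_size: int = 448) -> torch.Tensor:
    # Count the patch_size x patch_size tiles needed to cover the image
    # (rounded up) and return a dummy (num_patches, 3, H, W) tensor.
    cols = -(-width // patch_size)   # ceil division
    rows = -(-height // patch_size)
    return torch.zeros(rows * cols, 3, patch_size, patch_size)

full = fake_image_to_pixel_values(896, 896)  # size_factor 1.0 -> 4 patches
half = fake_image_to_pixel_values(448, 448)  # size_factor 0.5 -> 1 patch
print(full.shape[0], half.shape[0])          # 4 1: shapes differ across the batch
```

The new test deliberately batches both scales together, which is exactly the case the old input mapper could not handle.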


@@ -270,6 +270,7 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
         # Add an N dimension for number of images per prompt (currently 1).
         data = data.unsqueeze(0)
     elif is_list_of(data, Image.Image):
+        # we can't stack here because the images may have different num_patches
         data = [
             image_to_pixel_values(img,
                                   image_size,
@@ -277,7 +278,6 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
                                   max_num,
                                   use_thumbnail=use_thumbnail) for img in data
         ]
-        data = torch.stack(data)
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)
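
The removed `torch.stack` is the crux of the bug: stacking requires every per-image tensor to have the same shape, which fails as soon as images in a batch produce different numbers of patches. A small sketch of the failure mode (dummy tensors, hypothetical shapes):

```python
import torch

per_image = [
    torch.zeros(4, 3, 448, 448),  # image tiled into 4 patches
    torch.zeros(1, 3, 448, 448),  # image tiled into 1 patch
]

try:
    torch.stack(per_image)        # stack demands identical shapes
except RuntimeError as err:
    print(f"stack failed: {err}")

# Returning the list as-is defers merging the ragged patch dimension to the
# model (see the flatten_bn change below).
```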
@@ -449,11 +449,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # We need to flatten (B, N, P) to (B*N*P),
+            # so we call flatten_bn twice.
             return InternVLImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True).flatten(0, 1)),
+                    flatten_bn(flatten_bn(pixel_values), concat=True)),
             )
 
         raise AssertionError("This line should be unreachable.")
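
To make the double `flatten_bn` call concrete, here is a behaviorally similar stand-in applied to a nested pixel-value structure. This is a simplified sketch written from the comment above, not the library source:

```python
from typing import List, Union

import torch


def flatten_bn(x: Union[torch.Tensor, List], *, concat: bool = False):
    """Simplified stand-in: merge the batch (B) and per-prompt (N) dims."""
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)        # (B, N, ...) -> (B*N, ...)
    if concat:
        return torch.cat(x)           # ragged (P_i, ...) list -> (sum P_i, ...)
    return [x_n for x_b in x for x_n in x_b]  # flatten one nesting level


pixel_values = [                      # B = 2 prompts
    [torch.zeros(4, 3, 448, 448)],    # prompt 0: one image with 4 patches
    [torch.zeros(1, 3, 448, 448),     # prompt 1: two images with
     torch.zeros(5, 3, 448, 448)],    #   1 and 5 patches respectively
]

# The first call flattens (B, N) into B*N per-image tensors; the second,
# with concat=True, joins the ragged patch dimension into a single tensor.
flat = flatten_bn(flatten_bn(pixel_values), concat=True)
print(flat.shape)                     # torch.Size([10, 3, 448, 448])
```

Because the second call concatenates rather than stacks, per-image patch counts are free to vary, which is what makes batched inference with different num_patches work.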