diff --git a/tests/conftest.py b/tests/conftest.py
index 6e033e76..08a2c8fc 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,7 +24,9 @@ from vllm.assets.image import ImageAsset
 from vllm.config import TokenizerPoolConfig
 from vllm.connections import global_http_connection
 from vllm.distributed import (destroy_distributed_environment,
-                              destroy_model_parallel)
+                              destroy_model_parallel,
+                              init_distributed_environment,
+                              initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
@@ -90,6 +92,21 @@ def init_test_http_connection():
     global_http_connection.reuse_client = False
 
 
+@pytest.fixture
+def dist_init():
+    temp_file = tempfile.mkstemp()[1]
+    init_distributed_environment(
+        world_size=1,
+        rank=0,
+        distributed_init_method=f"file://{temp_file}",
+        local_rank=0,
+        backend="nccl",
+    )
+    initialize_model_parallel(1, 1)
+    yield
+    cleanup()
+
+
 def cleanup():
     destroy_model_parallel()
     destroy_distributed_environment()
diff --git a/tests/models/test_intern_vit.py b/tests/models/test_intern_vit.py
new file mode 100644
index 00000000..e980446f
--- /dev/null
+++ b/tests/models/test_intern_vit.py
@@ -0,0 +1,80 @@
+from typing import Optional
+
+import pytest
+import torch
+import torch.nn as nn
+from huggingface_hub import snapshot_download
+from transformers import AutoConfig, AutoModel, CLIPImageProcessor
+
+from vllm.model_executor.models.intern_vit import InternVisionModel
+
+from ..conftest import _ImageAssets, cleanup
+
+pytestmark = pytest.mark.vlm
+
+# we use snapshot_download to prevent conflicts between
+# dynamic_module and trust_remote_code for hf_runner
+DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
+models = [
+    snapshot_download("OpenGVLab/InternViT-300M-448px",
+                      allow_patterns=DOWNLOAD_PATTERN),
+    snapshot_download("OpenGVLab/InternViT-6B-448px-V1-5",
+                      allow_patterns=DOWNLOAD_PATTERN),
+]
+
+
+def run_intern_vit_test(
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    dtype: str,
+    distributed_executor_backend: Optional[str] = None,
+):
+    img_processor = CLIPImageProcessor.from_pretrained(model)
+    images = [asset.pil_image for asset in image_assets]
+    pixel_values = [
+        img_processor(images, return_tensors='pt').pixel_values.to(dtype)
+        for images in images
+    ]
+
+    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+    if not getattr(config, "norm_type", None):
+        config.norm_type = "rms_norm"
+
+    hf_model = AutoModel.from_pretrained(model,
+                                         torch_dtype=dtype,
+                                         trust_remote_code=True).to("cuda")
+    hf_outputs_per_image = [
+        hf_model(pixel_value.to("cuda")).last_hidden_state
+        for pixel_value in pixel_values
+    ]
+
+    vllm_model = InternVisionModel(config)
+    vllm_model.load_weights(hf_model.state_dict().items())
+
+    del hf_model
+    cleanup()
+
+    vllm_model = vllm_model.to("cuda", dtype)
+    vllm_outputs_per_image = [
+        vllm_model(pixel_values=pixel_value.to("cuda"))
+        for pixel_value in pixel_values
+    ]
+    del vllm_model
+    cleanup()
+
+    cos_similar = nn.CosineSimilarity(dim=-1)
+    for vllm_output, hf_output in zip(vllm_outputs_per_image,
+                                      hf_outputs_per_image):
+        assert cos_similar(vllm_output, hf_output).mean() > 0.99
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", [torch.half])
+@torch.inference_mode()
+def test_models(dist_init, image_assets, model, dtype: str) -> None:
+    run_intern_vit_test(
+        image_assets,
+        model,
+        dtype=dtype,
+    )
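
For context, a minimal sketch (hypothetical test, not part of this diff) of how the new dist_init fixture is meant to be consumed: pytest injects it by name, so the single-rank distributed environment and model-parallel groups exist before InternVisionModel's tensor-parallel layers are constructed, and cleanup() tears everything down afterwards. The test name and the reduced allow_patterns are assumptions for illustration; the model repo, config handling, and fixture come from the diff above.

import pytest
import torch
from huggingface_hub import snapshot_download
from transformers import AutoConfig

from vllm.model_executor.models.intern_vit import InternVisionModel


@pytest.mark.vlm
@torch.inference_mode()
def test_intern_vit_constructs(dist_init):
    # `dist_init` is the fixture added to tests/conftest.py above: requesting
    # it by name initializes a world_size=1 NCCL process group plus the
    # model-parallel groups before this body runs, and calls cleanup() after.
    path = snapshot_download("OpenGVLab/InternViT-300M-448px",
                             allow_patterns=["*.json", "*.py"])
    config = AutoConfig.from_pretrained(path, trust_remote_code=True)
    if not getattr(config, "norm_type", None):
        config.norm_type = "rms_norm"  # same default the diff's test applies
    model = InternVisionModel(config)
    assert isinstance(model, torch.nn.Module)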