From a61f0521b8d0d53a91951bb56789ead397d5cd83 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Sun, 18 Feb 2024 16:44:50 -0800
Subject: [PATCH] [Test] Add basic correctness test (#2908)

---
 .buildkite/test-pipeline.yaml              | 12 +++++-
 .../test_basic_correctness.py              | 38 +++++++++++++++++
 tests/conftest.py                          |  2 +
 .../test_basic_distributed_correctness.py  | 41 +++++++++++++++++++
 4 files changed, 91 insertions(+), 2 deletions(-)
 create mode 100644 tests/basic_correctness/test_basic_correctness.py
 create mode 100644 tests/distributed/test_basic_distributed_correctness.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 2e417ef9..a91dcdfa 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -11,8 +11,16 @@ steps:
 - label: AsyncEngine Test
   command: pytest -v -s async_engine
 
-- label: Distributed Test
-  command: pytest -v -s test_comm_ops.py
+- label: Basic Correctness Test
+  command: pytest -v -s --forked basic_correctness
+
+- label: Distributed Comm Ops Test
+  command: pytest -v -s --forked test_comm_ops.py
+  working_dir: "/vllm-workspace/tests/distributed"
+  num_gpus: 2 # only support 1 or 2 for now.
+
+- label: Distributed Correctness Test
+  command: pytest -v -s --forked test_basic_distributed_correctness.py
   working_dir: "/vllm-workspace/tests/distributed"
   num_gpus: 2 # only support 1 or 2 for now.
 
diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
new file mode 100644
index 00000000..fe67e0f2
--- /dev/null
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -0,0 +1,38 @@
+"""Compare the short outputs of HF and vLLM when using greedy sampling.
+
+Run `pytest tests/basic_correctness/test_basic_correctness.py --forked`.
+"""
+import pytest
+
+MODELS = [
+    "facebook/opt-125m",
+    "meta-llama/Llama-2-7b-hf",
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+    del hf_model
+
+    vllm_model = vllm_runner(model, dtype=dtype)
+    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
+
+    for i in range(len(example_prompts)):
+        hf_output_ids, hf_output_str = hf_outputs[i]
+        vllm_output_ids, vllm_output_str = vllm_outputs[i]
+        assert hf_output_str == vllm_output_str, (
+            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
+        assert hf_output_ids == vllm_output_ids, (
+            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
diff --git a/tests/conftest.py b/tests/conftest.py
index 8d6afdbd..941d48ec 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -165,6 +165,7 @@ class VllmRunner:
         model_name: str,
         tokenizer_name: Optional[str] = None,
         dtype: str = "half",
+        tensor_parallel_size: int = 1,
     ) -> None:
         self.model = LLM(
             model=model_name,
@@ -172,6 +173,7 @@ class VllmRunner:
             trust_remote_code=True,
             dtype=dtype,
             swap_space=0,
+            tensor_parallel_size=tensor_parallel_size,
         )
 
     def generate(
diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py
new file mode 100644
index 00000000..82075356
--- /dev/null
+++ b/tests/distributed/test_basic_distributed_correctness.py
@@ -0,0 +1,41 @@
+"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
+
+Run `pytest tests/distributed/test_basic_distributed_correctness.py --forked`.
+"""
+import pytest
+import torch
+
+MODELS = [
+    "facebook/opt-125m",
+    "meta-llama/Llama-2-7b-hf",
+]
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+    del hf_model
+
+    vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2)
+    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
+
+    for i in range(len(example_prompts)):
+        hf_output_ids, hf_output_str = hf_outputs[i]
+        vllm_output_ids, vllm_output_str = vllm_outputs[i]
+        assert hf_output_str == vllm_output_str, (
+            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
+        assert hf_output_ids == vllm_output_ids, (
+            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
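
For reference, the equivalence check these tests automate can be reproduced
outside pytest. The sketch below is a hypothetical standalone version, not
part of the patch: it assumes only the public `transformers` and `vllm` APIs,
and the model name, prompt, and token budget are illustrative.

    # Hypothetical standalone sketch (not from the patch): compare greedy
    # outputs of Hugging Face and vLLM for a single prompt.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from vllm import LLM, SamplingParams

    model_name = "facebook/opt-125m"  # illustrative; any causal LM works
    prompt = "Hello, my name is"
    max_tokens = 5

    # Greedy decoding with Hugging Face: do_sample=False takes the argmax
    # token at every step.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModelForCausalLM.from_pretrained(model_name)
    inputs = tokenizer(prompt, return_tensors="pt")
    hf_ids = hf_model.generate(**inputs, max_new_tokens=max_tokens,
                               do_sample=False)
    hf_text = tokenizer.decode(hf_ids[0], skip_special_tokens=True)

    # Greedy decoding with vLLM: temperature=0.0 also selects the argmax
    # token. vLLM returns only the completion, so prepend the prompt before
    # comparing against the HF decode of the full sequence.
    llm = LLM(model=model_name)
    params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
    vllm_text = prompt + llm.generate([prompt], params)[0].outputs[0].text

    # Under greedy decoding the two engines should agree exactly.
    assert hf_text == vllm_text, f"HF: {hf_text!r}\nvLLM: {vllm_text!r}"

Because greedy decoding is deterministic, any divergence between the two
engines points at a numerical or implementation difference rather than
sampling noise, which is what makes this a useful correctness signal in CI.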