[Test] Add basic correctness test (#2908)
This commit is contained in:
parent
537c9755a7
commit
a61f0521b8
@ -11,8 +11,16 @@ steps:
|
|||||||
- label: AsyncEngine Test
|
- label: AsyncEngine Test
|
||||||
command: pytest -v -s async_engine
|
command: pytest -v -s async_engine
|
||||||
|
|
||||||
- label: Distributed Test
|
- label: Basic Correctness Test
|
||||||
command: pytest -v -s test_comm_ops.py
|
command: pytest -v -s --forked basic_correctness
|
||||||
|
|
||||||
|
- label: Distributed Comm Ops Test
|
||||||
|
command: pytest -v -s --forked test_comm_ops.py
|
||||||
|
working_dir: "/vllm-workspace/tests/distributed"
|
||||||
|
num_gpus: 2 # only support 1 or 2 for now.
|
||||||
|
|
||||||
|
- label: Distributed Correctness Test
|
||||||
|
command: pytest -v -s --forked test_basic_distributed_correctness.py
|
||||||
working_dir: "/vllm-workspace/tests/distributed"
|
working_dir: "/vllm-workspace/tests/distributed"
|
||||||
num_gpus: 2 # only support 1 or 2 for now.
|
num_gpus: 2 # only support 1 or 2 for now.
|
||||||
|
|
||||||
|
|||||||
38
tests/basic_correctness/test_basic_correctness.py
Normal file
38
tests/basic_correctness/test_basic_correctness.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
"""Compare the short outputs of HF and vLLM when using greedy sampling.
|
||||||
|
|
||||||
|
Run `pytest tests/basic_correctness/test_basic_correctness.py --forked`.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
"facebook/opt-125m",
|
||||||
|
"meta-llama/Llama-2-7b-hf",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [5])
|
||||||
|
def test_models(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> None:
|
||||||
|
hf_model = hf_runner(model, dtype=dtype)
|
||||||
|
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del hf_model
|
||||||
|
|
||||||
|
vllm_model = vllm_runner(model, dtype=dtype)
|
||||||
|
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del vllm_model
|
||||||
|
|
||||||
|
for i in range(len(example_prompts)):
|
||||||
|
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||||
|
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||||
|
assert hf_output_str == vllm_output_str, (
|
||||||
|
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||||
|
assert hf_output_ids == vllm_output_ids, (
|
||||||
|
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||||
@ -165,6 +165,7 @@ class VllmRunner:
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
tokenizer_name: Optional[str] = None,
|
tokenizer_name: Optional[str] = None,
|
||||||
dtype: str = "half",
|
dtype: str = "half",
|
||||||
|
tensor_parallel_size: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = LLM(
|
self.model = LLM(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
@ -172,6 +173,7 @@ class VllmRunner:
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
swap_space=0,
|
swap_space=0,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
|
|||||||
41
tests/distributed/test_basic_distributed_correctness.py
Normal file
41
tests/distributed/test_basic_distributed_correctness.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
|
||||||
|
|
||||||
|
Run `pytest tests/distributed/test_basic_distributed_correctness.py --forked`.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
"facebook/opt-125m",
|
||||||
|
"meta-llama/Llama-2-7b-hf",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
|
reason="Need at least 2 GPUs to run the test.")
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [5])
|
||||||
|
def test_models(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> None:
|
||||||
|
hf_model = hf_runner(model, dtype=dtype)
|
||||||
|
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del hf_model
|
||||||
|
|
||||||
|
vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2)
|
||||||
|
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del vllm_model
|
||||||
|
|
||||||
|
for i in range(len(example_prompts)):
|
||||||
|
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||||
|
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||||
|
assert hf_output_str == vllm_output_str, (
|
||||||
|
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||||
|
assert hf_output_ids == vllm_output_ids, (
|
||||||
|
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||||
Loading…
Reference in New Issue
Block a user