From 09c2eb85ddd3b2585979f4cd9cc97168d86718b6 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 16 Jul 2024 15:44:22 -0700
Subject: [PATCH] [ci][distributed] add pipeline parallel correctness test (#6410)

---
 .buildkite/test-pipeline.yaml               |   9 +-
 tests/distributed/test_pipeline_parallel.py | 213 ++++++++++----------
 vllm/executor/multiproc_gpu_executor.py     |  15 ++
 3 files changed, 119 insertions(+), 118 deletions(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index cd3a5e80..445d74d6 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -72,7 +72,7 @@ steps:
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
-  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
+  - pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
@@ -115,12 +115,7 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 4
   commands:
-  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
-  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
-  - TP_SIZE=1 PP_SIZE=3 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
-  - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
-  - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
-
+  - pytest -v -s distributed/test_pipeline_parallel.py
 
 - label: Engine Test
   mirror_hardwares: [amd]
diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 2d9f6379..5e824b0f 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -1,25 +1,18 @@
-import os
-
-import openai  # use the official client for correctness check
 import pytest
 
 from ..utils import RemoteOpenAIServer
 
-# downloading lora to test lora requests
-# any model with a chat template should work here
-MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
-EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0)))
-CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0)))
-TP_SIZE = int(os.getenv("TP_SIZE", 1))
-PP_SIZE = int(os.getenv("PP_SIZE", 1))
-
-pytestmark = pytest.mark.asyncio
-
-
-@pytest.fixture(scope="module")
-def server():
-    args = [
+@pytest.mark.parametrize(
+    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
+    ])
+def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
+    pp_args = [
         "--model",
         MODEL_NAME,
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -32,109 +25,107 @@ def server():
         "--distributed-executor-backend",
         "ray",
     ]
+
+    # compare without pipeline parallelism
+    # NOTE: use mp backend for TP
+    # PP tests might involve multiple nodes, and ray might
+    # schedule all workers in a node other than the head node,
+    # which can cause the test to fail.
+    tp_args = [
+        "--model",
+        MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--tensor-parallel-size",
+        str(max(TP_SIZE, 2)),  # use at least TP_SIZE=2 to hold the model
+        "--distributed-executor-backend",
+        "mp",
+    ]
+
     if CHUNKED_PREFILL:
-        args += [
-            "--enable-chunked-prefill",
-        ]
+        pp_args.append("--enable-chunked-prefill")
+        tp_args.append("--enable-chunked-prefill")
     if EAGER_MODE:
-        args += [
-            "--enforce-eager",
-        ]
-    with RemoteOpenAIServer(args) as remote_server:
-        yield remote_server
+        pp_args.append("--enforce-eager")
+        tp_args.append("--enforce-eager")
+
+    results = []
+    for args in [pp_args, tp_args]:
+        with RemoteOpenAIServer(args) as server:
+            client = server.get_client()
+
-@pytest.fixture(scope="module")
-def client(server):
-    return server.get_async_client()
+            # test models list
+            models = client.models.list()
+            models = models.data
+            served_model = models[0]
+            results.append({
+                "test": "models_list",
+                "id": served_model.id,
+                "root": served_model.root,
+            })
+
+            # test with text prompt
+            completion = client.completions.create(model=MODEL_NAME,
+                                                   prompt="Hello, my name is",
+                                                   max_tokens=5,
+                                                   temperature=0.0)
+
-async def test_check_models(server, client: openai.AsyncOpenAI):
-    models = await client.models.list()
-    models = models.data
-    served_model = models[0]
-    assert served_model.id == MODEL_NAME
-    assert all(model.root == MODEL_NAME for model in models)
+            results.append({
+                "test": "single_completion",
+                "text": completion.choices[0].text,
+                "finish_reason": completion.choices[0].finish_reason,
+                "usage": completion.usage,
+            })
+
+            # test using token IDs
+            completion = client.completions.create(
+                model=MODEL_NAME,
+                prompt=[0, 0, 0, 0, 0],
+                max_tokens=5,
+                temperature=0.0,
+            )
+
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_single_completion(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
-    completion = await client.completions.create(model=model_name,
-                                                 prompt="Hello, my name is",
-                                                 max_tokens=5,
-                                                 temperature=0.0)
+            results.append({
+                "test": "token_ids",
+                "text": completion.choices[0].text,
+                "finish_reason": completion.choices[0].finish_reason,
+                "usage": completion.usage,
+            })
+
-    assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 1
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
-    assert completion.choices[0].finish_reason == "length"
-    assert completion.usage == openai.types.CompletionUsage(
-        completion_tokens=5, prompt_tokens=6, total_tokens=11)
+            # test simple list
+            batch = client.completions.create(
+                model=MODEL_NAME,
+                prompt=["Hello, my name is", "Hello, my name is"],
+                max_tokens=5,
+                temperature=0.0,
+            )
+
-    # test using token IDs
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=[0, 0, 0, 0, 0],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
+            results.append({
+                "test": "simple_list",
+                "text0": batch.choices[0].text,
+                "text1": batch.choices[1].text,
+            })
+
+            # test streaming
+            batch = client.completions.create(
+                model=MODEL_NAME,
+                prompt=["Hello, my name is", "Hello, my name is"],
+                max_tokens=5,
+                temperature=0.0,
+                stream=True,
+            )
+            texts = [""] * 2
+            for chunk in batch:
+                assert len(chunk.choices) == 1
+                choice = chunk.choices[0]
+                texts[choice.index] += choice.text
+
+            results.append({
+                "test": "streaming",
+                "texts": texts,
+            })
+
-@pytest.mark.parametrize(
-    # just test 1 lora hereafter
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_batch_completions(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
-    # test simple list
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert len(batch.choices) == 2
-    assert batch.choices[0].text == batch.choices[1].text
-
-    # test n = 2
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        n=2,
-        max_tokens=5,
-        temperature=0.0,
-        extra_body=dict(
-            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
-            # for official client.
-            use_beam_search=True),
-    )
-    assert len(batch.choices) == 4
-    assert batch.choices[0].text != batch.choices[
-        1].text, "beam search should be different"
-    assert batch.choices[0].text == batch.choices[
-        2].text, "two copies of the same prompt should be the same"
-    assert batch.choices[1].text == batch.choices[
-        3].text, "two copies of the same prompt should be the same"
-
-    # test streaming
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-    )
-    texts = [""] * 2
-    async for chunk in batch:
-        assert len(chunk.choices) == 1
-        choice = chunk.choices[0]
-        texts[choice.index] += choice.text
-    assert texts[0] == texts[1]
+    n = len(results) // 2
+    pp_results = results[:n]
+    tp_results = results[n:]
+    for pp, tp in zip(pp_results, tp_results):
+        assert pp == tp
diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index a0e248b2..01ed9d12 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -1,5 +1,7 @@
 import asyncio
 import os
+import signal
+import weakref
 from functools import partial
 from typing import Any, List, Optional
 
@@ -78,6 +80,19 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         result_handler.start()
         self.worker_monitor.start()
 
+        # Set up signal handlers to shut down the executor cleanly;
+        # sometimes gc does not work well
+
+        # Use weakref to avoid holding a reference to self
+        ref = weakref.ref(self)
+
+        def shutdown(signum, frame):
+            if executor := ref():
+                executor.shutdown()
+
+        signal.signal(signal.SIGINT, shutdown)
+        signal.signal(signal.SIGTERM, shutdown)
+
         self.driver_worker = self._create_worker(
             distributed_init_method=distributed_init_method)
         self._run_workers("init_device")
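
Note on the executor change (an illustrative sketch, not part of the patch): the SIGINT/SIGTERM handlers registered in MultiprocessingGPUExecutor hold only a weak reference to the executor, so installing them does not keep the executor alive after the rest of the program drops it, while a signal delivered earlier still triggers a clean shutdown. The standalone Python sketch below shows the same pattern outside vLLM; Executor and install_shutdown_handlers are hypothetical names used only for this illustration, not vLLM APIs.

    import signal
    import weakref


    class Executor:
        """Stand-in for an object that owns worker processes."""

        def shutdown(self):
            # The real executor would terminate its worker processes here.
            print("shutting down workers")


    def install_shutdown_handlers(executor):
        # Keep only a weak reference so the signal handlers do not
        # prevent the executor from being garbage collected.
        ref = weakref.ref(executor)

        def handler(signum, frame):
            # ref() returns None once the executor has been collected;
            # only call shutdown() while the object is still alive.
            instance = ref()
            if instance is not None:
                instance.shutdown()

        signal.signal(signal.SIGINT, handler)
        signal.signal(signal.SIGTERM, handler)


    if __name__ == "__main__":
        executor = Executor()
        install_shutdown_handlers(executor)
        # Ctrl-C (SIGINT) or kill <pid> (SIGTERM) now triggers
        # Executor.shutdown() as long as the executor is still referenced.
        signal.pause()  # POSIX only: block until a signal is delivered

Because the handler re-checks the weak reference on every delivery, it simply becomes a no-op once the executor has been garbage collected, instead of keeping the object alive the way a closure over a strong reference would.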