[ci][distributed] add pipeline parallel correctness test (#6410)

youkaichao 2024-07-16 15:44:22 -07:00 committed by GitHub
parent 978aed5300
commit 09c2eb85dd
3 changed files with 119 additions and 118 deletions

.buildkite/test-pipeline.yaml

@@ -72,7 +72,7 @@ steps:
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
-  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
+  - pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
@@ -115,12 +115,7 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 4
   commands:
-  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
-  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
-  - TP_SIZE=1 PP_SIZE=3 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
-  - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
-  - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
+  - pytest -v -s distributed/test_pipeline_parallel.py

 - label: Engine Test
   mirror_hardwares: [amd]
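
Note: with the TP_SIZE/PP_SIZE/EAGER_MODE/CHUNKED_PREFILL environment variables gone, the whole configuration matrix now runs from a single pytest invocation. A sketch of an equivalent local run (a minimal example, not part of this commit; it assumes a checkout with the tests directory on disk, enough GPUs for the largest configuration, and access to the Meta-Llama-3-8B weights):

import sys

import pytest

if __name__ == "__main__":
    # Same invocation the CI uses; append "-k" plus a parametrize-id substring
    # (e.g. "1-4-1-0") to run a single configuration.
    sys.exit(pytest.main(["-v", "-s", "distributed/test_pipeline_parallel.py"]))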

tests/distributed/test_pipeline_parallel.py

@@ -1,25 +1,18 @@
-import os
-
-import openai  # use the official client for correctness check
 import pytest

 from ..utils import RemoteOpenAIServer

-# downloading lora to test lora requests
-
-# any model with a chat template should work here
-MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
-EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0)))
-CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0)))
-TP_SIZE = int(os.getenv("TP_SIZE", 1))
-PP_SIZE = int(os.getenv("PP_SIZE", 1))
-
-pytestmark = pytest.mark.asyncio
-
-
-@pytest.fixture(scope="module")
-def server():
-    args = [
+
+@pytest.mark.parametrize(
+    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
+    ])
+def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
+    pp_args = [
         "--model",
         MODEL_NAME,
         # use half precision for speed and memory savings in CI environment
@@ -32,109 +25,107 @@ def server():
         "--distributed-executor-backend",
         "ray",
     ]
+
+    # compare without pipeline parallelism
+    # NOTE: use mp backend for TP
+    # PP tests might involve multiple nodes, and ray might
+    # schedule all workers in a node other than the head node,
+    # which can cause the test to fail.
+    tp_args = [
+        "--model",
+        MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--tensor-parallel-size",
+        str(max(TP_SIZE, 2)),  # use at least TP_SIZE=2 to hold the model
+        "--distributed-executor-backend",
+        "mp",
+    ]
     if CHUNKED_PREFILL:
-        args += [
-            "--enable-chunked-prefill",
-        ]
+        pp_args.append("--enable-chunked-prefill")
+        tp_args.append("--enable-chunked-prefill")
     if EAGER_MODE:
-        args += [
-            "--enforce-eager",
-        ]
-
-    with RemoteOpenAIServer(args) as remote_server:
-        yield remote_server
-
-
-@pytest.fixture(scope="module")
-def client(server):
-    return server.get_async_client()
-
-
-async def test_check_models(server, client: openai.AsyncOpenAI):
-    models = await client.models.list()
-    models = models.data
-    served_model = models[0]
-    assert served_model.id == MODEL_NAME
-    assert all(model.root == MODEL_NAME for model in models)
-
-
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_single_completion(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
-    completion = await client.completions.create(model=model_name,
-                                                 prompt="Hello, my name is",
-                                                 max_tokens=5,
-                                                 temperature=0.0)
-
-    assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 1
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
-    assert completion.choices[0].finish_reason == "length"
-    assert completion.usage == openai.types.CompletionUsage(
-        completion_tokens=5, prompt_tokens=6, total_tokens=11)
-
-    # test using token IDs
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=[0, 0, 0, 0, 0],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
-
-
-@pytest.mark.parametrize(
-    # just test 1 lora hereafter
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_batch_completions(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
-    # test simple list
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert len(batch.choices) == 2
-    assert batch.choices[0].text == batch.choices[1].text
-
-    # test n = 2
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        n=2,
-        max_tokens=5,
-        temperature=0.0,
-        extra_body=dict(
-            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
-            # for official client.
-            use_beam_search=True),
-    )
-    assert len(batch.choices) == 4
-    assert batch.choices[0].text != batch.choices[
-        1].text, "beam search should be different"
-    assert batch.choices[0].text == batch.choices[
-        2].text, "two copies of the same prompt should be the same"
-    assert batch.choices[1].text == batch.choices[
-        3].text, "two copies of the same prompt should be the same"
-
-    # test streaming
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-    )
-    texts = [""] * 2
-    async for chunk in batch:
-        assert len(chunk.choices) == 1
-        choice = chunk.choices[0]
-        texts[choice.index] += choice.text
-    assert texts[0] == texts[1]
+        pp_args.append("--enforce-eager")
+        tp_args.append("--enforce-eager")
+
+    results = []
+    for args in [pp_args, tp_args]:
+        with RemoteOpenAIServer(args) as server:
+            client = server.get_client()
+
+            # test models list
+            models = client.models.list()
+            models = models.data
+            served_model = models[0]
+            results.append({
+                "test": "models_list",
+                "id": served_model.id,
+                "root": served_model.root,
+            })
+
+            # test with text prompt
+            completion = client.completions.create(model=MODEL_NAME,
+                                                   prompt="Hello, my name is",
+                                                   max_tokens=5,
+                                                   temperature=0.0)
+
+            results.append({
+                "test": "single_completion",
+                "text": completion.choices[0].text,
+                "finish_reason": completion.choices[0].finish_reason,
+                "usage": completion.usage,
+            })
+
+            # test using token IDs
+            completion = client.completions.create(
+                model=MODEL_NAME,
+                prompt=[0, 0, 0, 0, 0],
+                max_tokens=5,
+                temperature=0.0,
+            )
+
+            results.append({
+                "test": "token_ids",
+                "text": completion.choices[0].text,
+                "finish_reason": completion.choices[0].finish_reason,
+                "usage": completion.usage,
+            })
+
+            # test simple list
+            batch = client.completions.create(
+                model=MODEL_NAME,
+                prompt=["Hello, my name is", "Hello, my name is"],
+                max_tokens=5,
+                temperature=0.0,
+            )
+
+            results.append({
+                "test": "simple_list",
+                "text0": batch.choices[0].text,
+                "text1": batch.choices[1].text,
+            })
+
+            # test streaming
+            batch = client.completions.create(
+                model=MODEL_NAME,
+                prompt=["Hello, my name is", "Hello, my name is"],
+                max_tokens=5,
+                temperature=0.0,
+                stream=True,
+            )
+            texts = [""] * 2
+            for chunk in batch:
+                assert len(chunk.choices) == 1
+                choice = chunk.choices[0]
+                texts[choice.index] += choice.text
+            results.append({
+                "test": "streaming",
+                "texts": texts,
+            })
+
+    n = len(results) // 2
+    pp_results = results[:n]
+    tp_results = results[n:]
+    for pp, tp in zip(pp_results, tp_results):
+        assert pp == tp
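
Note: the new test launches the server twice with the same model and deterministic sampling (temperature 0), once with pipeline parallelism over the ray backend and once with plain tensor parallelism over the mp backend as a reference, records every response as a plain dict, and requires the two result lists to match exactly. A minimal sketch of that comparison shape (the collect_results helper is hypothetical, standing in for the server start-up and request code in the diff above):

from typing import Any, Callable, Dict, List


def compare_configs(
        collect_results: Callable[[List[str]], List[Dict[str, Any]]],
        pp_args: List[str], tp_args: List[str]) -> None:
    """Run the same deterministic requests against two server configurations
    and require identical results."""
    pp_results = collect_results(pp_args)  # pipeline-parallel run
    tp_results = collect_results(tp_args)  # tensor-parallel reference run
    assert len(pp_results) == len(tp_results)
    for pp, tp in zip(pp_results, tp_results):
        assert pp == tp, f"PP/TP mismatch: {pp} != {tp}"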

vllm/executor/multiproc_gpu_executor.py

@@ -1,5 +1,7 @@
 import asyncio
 import os
+import signal
+import weakref
 from functools import partial
 from typing import Any, List, Optional
@@ -78,6 +80,19 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         result_handler.start()
         self.worker_monitor.start()

+        # Set up signal handlers to shutdown the executor cleanly
+        # sometimes gc does not work well
+
+        # Use weakref to avoid holding a reference to self
+        ref = weakref.ref(self)
+
+        def shutdown(signum, frame):
+            if executor := ref():
+                executor.shutdown()
+
+        signal.signal(signal.SIGINT, shutdown)
+        signal.signal(signal.SIGTERM, shutdown)
+
         self.driver_worker = self._create_worker(
             distributed_init_method=distributed_init_method)
         self._run_workers("init_device")
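
Note: the handler added above pairs signal.signal with weakref.ref so that registering the handler does not keep the executor object alive. A minimal standalone sketch of the same pattern (the Executor class below is a toy stand-in for illustration, not vLLM code):

import signal
import weakref


class Executor:
    """Toy stand-in for the real executor class (illustration only)."""

    def __init__(self):
        # A weak reference lets the signal handler reach the executor without
        # preventing it from being garbage collected once callers drop it.
        ref = weakref.ref(self)

        def handler(signum, frame):
            executor = ref()  # None if the executor was already collected
            if executor is not None:
                executor.shutdown()

        signal.signal(signal.SIGINT, handler)
        signal.signal(signal.SIGTERM, handler)

    def shutdown(self):
        print("shutting down worker processes")


if __name__ == "__main__":
    ex = Executor()
    signal.raise_signal(signal.SIGTERM)  # simulate an external kill / Ctrl-C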