[ci][distributed] add pipeline parallel correctness test (#6410)

This commit is contained in:
youkaichao 2024-07-16 15:44:22 -07:00 committed by GitHub
parent 978aed5300
commit 09c2eb85dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 119 additions and 118 deletions

View File

@ -72,7 +72,7 @@ steps:
commands:
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
- pytest -v -s distributed/test_pipeline_parallel.py
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
@ -115,12 +115,7 @@ steps:
working_dir: "/vllm-workspace/tests"
num_gpus: 4
commands:
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
- TP_SIZE=1 PP_SIZE=3 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
- pytest -v -s distributed/test_pipeline_parallel.py
- label: Engine Test
mirror_hardwares: [amd]

View File

@ -1,25 +1,18 @@
import os
import openai # use the official client for correctness check
import pytest
from ..utils import RemoteOpenAIServer
# downloading lora to test lora requests
# any model with a chat template should work here
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0)))
CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0)))
TP_SIZE = int(os.getenv("TP_SIZE", 1))
PP_SIZE = int(os.getenv("PP_SIZE", 1))
pytestmark = pytest.mark.asyncio
@pytest.fixture(scope="module")
def server():
args = [
@pytest.mark.parametrize(
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
])
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
pp_args = [
"--model",
MODEL_NAME,
# use half precision for speed and memory savings in CI environment
@ -32,109 +25,107 @@ def server():
"--distributed-executor-backend",
"ray",
]
# compare without pipeline parallelism
# NOTE: use mp backend for TP
# PP tests might involve multiple nodes, and ray might
# schedule all workers in a node other than the head node,
# which can cause the test to fail.
tp_args = [
"--model",
MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--tensor-parallel-size",
str(max(TP_SIZE, 2)), # use at least TP_SIZE=2 to hold the model
"--distributed-executor-backend",
"mp",
]
if CHUNKED_PREFILL:
args += [
"--enable-chunked-prefill",
]
pp_args.append("--enable-chunked-prefill")
tp_args.append("--enable-chunked-prefill")
if EAGER_MODE:
args += [
"--enforce-eager",
]
with RemoteOpenAIServer(args) as remote_server:
yield remote_server
pp_args.append("--enforce-eager")
tp_args.append("--enforce-eager")
results = []
for args in [pp_args, tp_args]:
with RemoteOpenAIServer(args) as server:
client = server.get_client()
@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()
# test models list
models = client.models.list()
models = models.data
served_model = models[0]
results.append({
"test": "models_list",
"id": served_model.id,
"root": served_model.root,
})
# test with text prompt
completion = client.completions.create(model=MODEL_NAME,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
async def test_check_models(server, client: openai.AsyncOpenAI):
models = await client.models.list()
models = models.data
served_model = models[0]
assert served_model.id == MODEL_NAME
assert all(model.root == MODEL_NAME for model in models)
results.append({
"test": "single_completion",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test using token IDs
completion = client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_single_completion(server, client: openai.AsyncOpenAI,
model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
results.append({
"test": "token_ids",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
# test simple list
batch = client.completions.create(
model=MODEL_NAME,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
)
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
results.append({
"test": "simple_list",
"text0": batch.choices[0].text,
"text1": batch.choices[1].text,
})
# test streaming
batch = client.completions.create(
model=MODEL_NAME,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
results.append({
"test": "streaming",
"texts": texts,
})
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
[MODEL_NAME],
)
async def test_batch_completions(server, client: openai.AsyncOpenAI,
model_name: str):
# test simple list
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
)
assert len(batch.choices) == 2
assert batch.choices[0].text == batch.choices[1].text
# test n = 2
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but not necessary
# for official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"
# test streaming
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
async for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
assert texts[0] == texts[1]
n = len(results) // 2
pp_results = results[:n]
tp_results = results[n:]
for pp, tp in zip(pp_results, tp_results):
assert pp == tp

View File

@ -1,5 +1,7 @@
import asyncio
import os
import signal
import weakref
from functools import partial
from typing import Any, List, Optional
@ -78,6 +80,19 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
result_handler.start()
self.worker_monitor.start()
# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well
# Use weakref to avoid holding a reference to self
ref = weakref.ref(self)
def shutdown(signum, frame):
if executor := ref():
executor.shutdown()
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)
self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
self._run_workers("init_device")