[ci][distributed] add pipeline parallel correctness test (#6410)
parent 978aed5300
commit 09c2eb85dd
@@ -72,7 +72,7 @@ steps:
  commands:
  - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
  - pytest -v -s distributed/test_pipeline_parallel.py
  - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
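Aside (not part of the diff): the two `torchrun` lines above launch the same script on both nodes against a shared c10d rendezvous at 192.168.10.10, and `VLLM_TEST_SAME_HOST=0` tells the test to expect the ranks to land on different hosts. Below is a minimal sketch of that kind of same-host check, assuming a gloo process group; it is not the actual `distributed/test_same_node.py`.

```python
# Hypothetical sketch of a same-host check driven by the torchrun commands above.
# torchrun supplies RANK/WORLD_SIZE/MASTER_ADDR through the c10d rendezvous.
import os
import socket

import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    hostnames = [None] * dist.get_world_size()
    dist.all_gather_object(hostnames, socket.gethostname())
    expect_same_host = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
    # with VLLM_TEST_SAME_HOST=0, the ranks should span more than one host
    assert (len(set(hostnames)) == 1) == expect_same_host
    dist.destroy_process_group()
```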
@@ -115,12 +115,7 @@ steps:
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
  commands:
  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
  - TP_SIZE=1 PP_SIZE=3 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
  - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
  - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
  - pytest -v -s distributed/test_pipeline_parallel.py

- label: Engine Test
  mirror_hardwares: [amd]
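Aside (not part of the diff): the five environment-variable invocations above collapse into a single plain `pytest` command because the TP/PP/eager-mode/chunked-prefill combinations now live in `@pytest.mark.parametrize` inside the test file (next hunk). A rough sketch of that migration pattern, with made-up names; a single combination can still be selected by its generated ID substring via `-k`.

```python
# Hypothetical illustration of moving an env-var test matrix into parametrization.
# Each tuple becomes one generated test; one case can be picked by ID substring:
#   pytest -v -s test_file.py -k "2-2-1-0"
import pytest


@pytest.mark.parametrize("tp_size, pp_size, eager_mode, chunked_prefill", [
    (2, 2, 1, 0),  # analogous to TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0
    (1, 4, 1, 1),  # analogous to PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1
])
def test_matrix_entry(tp_size, pp_size, eager_mode, chunked_prefill):
    # placeholder body; the real test builds vLLM server args from these values
    assert tp_size >= 1 and pp_size >= 1
```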
@@ -1,25 +1,18 @@
import os

import openai  # use the official client for correctness check
import pytest

from ..utils import RemoteOpenAIServer

# downloading lora to test lora requests

# any model with a chat template should work here
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0)))
CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0)))
TP_SIZE = int(os.getenv("TP_SIZE", 1))
PP_SIZE = int(os.getenv("PP_SIZE", 1))

pytestmark = pytest.mark.asyncio


@pytest.fixture(scope="module")
def server():
    args = [
@pytest.mark.parametrize(
    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
    ])
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
    pp_args = [
        "--model",
        MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
@@ -32,109 +25,107 @@ def server():
        "--distributed-executor-backend",
        "ray",
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    # schedule all workers in a node other than the head node,
    # which can cause the test to fail.
    tp_args = [
        "--model",
        MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--tensor-parallel-size",
        str(max(TP_SIZE, 2)),  # use at least TP_SIZE=2 to hold the model
        "--distributed-executor-backend",
        "mp",
    ]
    if CHUNKED_PREFILL:
        args += [
            "--enable-chunked-prefill",
        ]
        pp_args.append("--enable-chunked-prefill")
        tp_args.append("--enable-chunked-prefill")
    if EAGER_MODE:
        args += [
            "--enforce-eager",
        ]
    with RemoteOpenAIServer(args) as remote_server:
        yield remote_server
        pp_args.append("--enforce-eager")
        tp_args.append("--enforce-eager")

    results = []
    for args in [pp_args, tp_args]:
        with RemoteOpenAIServer(args) as server:
            client = server.get_client()


@pytest.fixture(scope="module")
def client(server):
    return server.get_async_client()
            # test models list
            models = client.models.list()
            models = models.data
            served_model = models[0]
            results.append({
                "test": "models_list",
                "id": served_model.id,
                "root": served_model.root,
            })

            # test with text prompt
            completion = client.completions.create(model=MODEL_NAME,
                                                   prompt="Hello, my name is",
                                                   max_tokens=5,
                                                   temperature=0.0)


async def test_check_models(server, client: openai.AsyncOpenAI):
    models = await client.models.list()
    models = models.data
    served_model = models[0]
    assert served_model.id == MODEL_NAME
    assert all(model.root == MODEL_NAME for model in models)
            results.append({
                "test": "single_completion",
                "text": completion.choices[0].text,
                "finish_reason": completion.choices[0].finish_reason,
                "usage": completion.usage,
            })

            # test using token IDs
            completion = client.completions.create(
                model=MODEL_NAME,
                prompt=[0, 0, 0, 0, 0],
                max_tokens=5,
                temperature=0.0,
            )


@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_single_completion(server, client: openai.AsyncOpenAI,
                                 model_name: str):
    completion = await client.completions.create(model=model_name,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)
            results.append({
                "test": "token_ids",
                "text": completion.choices[0].text,
                "finish_reason": completion.choices[0].finish_reason,
                "usage": completion.usage,
            })

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 1
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
    assert completion.choices[0].finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)
            # test simple list
            batch = client.completions.create(
                model=MODEL_NAME,
                prompt=["Hello, my name is", "Hello, my name is"],
                max_tokens=5,
                temperature=0.0,
            )

    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
            results.append({
                "test": "simple_list",
                "text0": batch.choices[0].text,
                "text1": batch.choices[1].text,
            })

            # test streaming
            batch = client.completions.create(
                model=MODEL_NAME,
                prompt=["Hello, my name is", "Hello, my name is"],
                max_tokens=5,
                temperature=0.0,
                stream=True,
            )
            texts = [""] * 2
            for chunk in batch:
                assert len(chunk.choices) == 1
                choice = chunk.choices[0]
                texts[choice.index] += choice.text
            results.append({
                "test": "streaming",
                "texts": texts,
            })


@pytest.mark.parametrize(
    # just test 1 lora hereafter
    "model_name",
    [MODEL_NAME],
)
async def test_batch_completions(server, client: openai.AsyncOpenAI,
                                 model_name: str):
    # test simple list
    batch = await client.completions.create(
        model=model_name,
        prompt=["Hello, my name is", "Hello, my name is"],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(batch.choices) == 2
    assert batch.choices[0].text == batch.choices[1].text

    # test n = 2
    batch = await client.completions.create(
        model=model_name,
        prompt=["Hello, my name is", "Hello, my name is"],
        n=2,
        max_tokens=5,
        temperature=0.0,
        extra_body=dict(
            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
            # for official client.
            use_beam_search=True),
    )
    assert len(batch.choices) == 4
    assert batch.choices[0].text != batch.choices[
        1].text, "beam search should be different"
    assert batch.choices[0].text == batch.choices[
        2].text, "two copies of the same prompt should be the same"
    assert batch.choices[1].text == batch.choices[
        3].text, "two copies of the same prompt should be the same"

    # test streaming
    batch = await client.completions.create(
        model=model_name,
        prompt=["Hello, my name is", "Hello, my name is"],
        max_tokens=5,
        temperature=0.0,
        stream=True,
    )
    texts = [""] * 2
    async for chunk in batch:
        assert len(chunk.choices) == 1
        choice = chunk.choices[0]
        texts[choice.index] += choice.text
    assert texts[0] == texts[1]
    n = len(results) // 2
    pp_results = results[:n]
    tp_results = results[n:]
    for pp, tp in zip(pp_results, tp_results):
        assert pp == tp

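Stepping back (not part of the diff): the new test reduces to a record-and-compare shape — issue the same deterministic (temperature=0.0) requests against a pipeline-parallel server and a tensor-parallel reference, store structured results, and require pairwise equality. A distilled sketch follows, where `run_requests` is a hypothetical stand-in for the `RemoteOpenAIServer` block above.

```python
from typing import Any, Callable, Dict, List


# Distilled sketch of the check above; `run_requests` is a hypothetical callable
# that launches a server with the given CLI args, issues the deterministic
# requests, and returns one result dict per request.
def compare_pp_against_tp(
        run_requests: Callable[[List[str]], List[Dict[str, Any]]],
        pp_args: List[str], tp_args: List[str]) -> None:
    pp_results = run_requests(pp_args)
    tp_results = run_requests(tp_args)
    assert len(pp_results) == len(tp_results)
    for pp, tp in zip(pp_results, tp_results):
        assert pp == tp, f"PP/TP mismatch: {pp} vs {tp}"
```

The pairwise equality only makes sense because every request is greedy (temperature=0.0) and the prompts are identical across the two runs; any nondeterministic sampling would make strict equality too brittle.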
@@ -1,5 +1,7 @@
import asyncio
import os
import signal
import weakref
from functools import partial
from typing import Any, List, Optional
@@ -78,6 +80,19 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
        result_handler.start()
        self.worker_monitor.start()

        # Set up signal handlers to shutdown the executor cleanly
        # sometimes gc does not work well

        # Use weakref to avoid holding a reference to self
        ref = weakref.ref(self)

        def shutdown(signum, frame):
            if executor := ref():
                executor.shutdown()

        signal.signal(signal.SIGINT, shutdown)
        signal.signal(signal.SIGTERM, shutdown)

        self.driver_worker = self._create_worker(
            distributed_init_method=distributed_init_method)
        self._run_workers("init_device")
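For illustration (not part of the diff): the handler registered above goes through a weakref so that installing it does not keep the executor alive past garbage collection. A self-contained sketch of the same pattern; `DummyExecutor` is a made-up stand-in, not the vLLM class.

```python
import signal
import weakref


class DummyExecutor:
    """Made-up stand-in showing the weakref + signal-handler shutdown pattern."""

    def __init__(self) -> None:
        # hold only a weak reference so the handler does not pin the executor
        ref = weakref.ref(self)

        def _shutdown(signum, frame):
            # resolves to None once the executor has been garbage collected
            if executor := ref():
                executor.shutdown()

        signal.signal(signal.SIGINT, _shutdown)
        signal.signal(signal.SIGTERM, _shutdown)

    def shutdown(self) -> None:
        print("cleaning up workers")


if __name__ == "__main__":
    executor = DummyExecutor()
    signal.raise_signal(signal.SIGTERM)  # invokes _shutdown -> "cleaning up workers"
```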