[Bugfix] Fix PP for Multi-Step (#8887)
This commit is contained in:
parent
39d3f8d94f
commit
19d02ff938
@ -142,3 +142,85 @@ async def test_multi_step(
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("tp_size, pp_size"), [
|
||||
(1, 2),
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_step_pp_smoke(
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""
|
||||
Smoke test for the vLLM engine with multi-step scheduling in an
|
||||
OpenAI-protocol client/server environment.
|
||||
|
||||
This tests compares the outputs between multi-step scheduling and
|
||||
single-step scheduling. Notably, this test lets the engines generate
|
||||
more tokens (default is 5) and test for an exact match over all the
|
||||
tokens.
|
||||
|
||||
Args:
|
||||
tp_size: degree of tensor-parallelism
|
||||
pp_size: degree of pipeline-parallelism
|
||||
eager_mode
|
||||
"""
|
||||
|
||||
model = "JackFram/llama-160m"
|
||||
num_scheduler_steps = 8
|
||||
attention_backend = "FLASH_ATTN"
|
||||
max_num_seqs = 3
|
||||
|
||||
override_backend_env_variable(monkeypatch, attention_backend)
|
||||
|
||||
# Prompt from the ShareGPT dataset
|
||||
prompts = [
|
||||
"in the jtbd context whats a push?", # codespell:ignore
|
||||
"in the jtbd context whats a push?", # codespell:ignore
|
||||
"in the jtbd context whats a push?", # codespell:ignore
|
||||
"in the jtbd context whats a push?", # codespell:ignore
|
||||
]
|
||||
# Use varying max_tokens to introduce scheduling randomness.
|
||||
max_tokens = [10 * i for i in range(1, len(prompts) + 1)]
|
||||
assert len(prompts) == len(max_tokens)
|
||||
|
||||
test_args = [
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size), "--pipeline-parallel-size",
|
||||
str(pp_size), "--max-num-seqs",
|
||||
str(max_num_seqs)
|
||||
]
|
||||
|
||||
server_args = DEFAULT_SERVER_ARGS + test_args
|
||||
ms_server_args = DEFAULT_SERVER_ARGS + \
|
||||
["--num-scheduler-steps", f"{num_scheduler_steps}"] + \
|
||||
test_args
|
||||
|
||||
# Spin up client/server & issue completion API requests.
|
||||
# Default `max_wait_seconds` is 240 but was empirically
|
||||
# was raised 3x to 720 *just for this test* due to
|
||||
# observed timeouts in GHA CI
|
||||
ref_completions = await completions_with_server_args(
|
||||
prompts=prompts,
|
||||
model_name=model,
|
||||
server_cli_args=server_args,
|
||||
num_logprobs=None,
|
||||
max_wait_seconds=5 * 240,
|
||||
max_tokens=max_tokens)
|
||||
|
||||
test_completions = await completions_with_server_args(
|
||||
prompts=prompts,
|
||||
model_name=model,
|
||||
server_cli_args=ms_server_args,
|
||||
num_logprobs=None,
|
||||
max_wait_seconds=5 * 240,
|
||||
max_tokens=max_tokens)
|
||||
|
||||
# Assert multi-step scheduling produces identical tokens
|
||||
# to single-step scheduling.
|
||||
ref_generations = get_client_text_generations(ref_completions)
|
||||
test_generations = get_client_text_generations(test_completions)
|
||||
|
||||
assert ref_generations == test_generations
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import signal
|
||||
@ -7,7 +8,7 @@ import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
@ -476,7 +477,8 @@ async def completions_with_server_args(
|
||||
server_cli_args: List[str],
|
||||
num_logprobs: Optional[int],
|
||||
max_wait_seconds: int = 240,
|
||||
) -> Completion:
|
||||
max_tokens: Union[int, list] = 5,
|
||||
) -> List[Completion]:
|
||||
'''Construct a remote OpenAI server, obtain an async client to the
|
||||
server & invoke the completions API to obtain completions.
|
||||
|
||||
@ -487,37 +489,49 @@ async def completions_with_server_args(
|
||||
num_logprobs: Number of logprobs to report (or `None`)
|
||||
max_wait_seconds: timeout interval for bringing up server.
|
||||
Default: 240sec
|
||||
max_tokens: max_tokens value for each of the given input prompts.
|
||||
if only one max_token value is given, the same value is used
|
||||
for all the prompts.
|
||||
|
||||
Returns:
|
||||
OpenAI Completion instance
|
||||
'''
|
||||
|
||||
if isinstance(max_tokens, int):
|
||||
max_tokens = [max_tokens] * len(prompts)
|
||||
|
||||
assert len(max_tokens) == len(prompts)
|
||||
|
||||
outputs = None
|
||||
max_wait_seconds = 240 * 3 # 240 is default
|
||||
with RemoteOpenAIServer(model_name,
|
||||
server_cli_args,
|
||||
max_wait_seconds=max_wait_seconds) as server:
|
||||
client = server.get_async_client()
|
||||
outputs = await client.completions.create(model=model_name,
|
||||
prompt=prompts,
|
||||
temperature=0,
|
||||
stream=False,
|
||||
max_tokens=5,
|
||||
logprobs=num_logprobs)
|
||||
outputs = [ client.completions.create(model=model_name,
|
||||
prompt=[p],
|
||||
temperature=0,
|
||||
stream=False,
|
||||
max_tokens=max_tok,
|
||||
logprobs=num_logprobs) \
|
||||
for p, max_tok in zip(prompts, max_tokens) ]
|
||||
outputs = await asyncio.gather(*outputs)
|
||||
|
||||
assert outputs is not None, "Completion API call failed."
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def get_client_text_generations(completions: Completion) -> List[str]:
|
||||
def get_client_text_generations(completions: List[Completion]) -> List[str]:
|
||||
'''Extract generated tokens from the output of a
|
||||
request made to an Open-AI-protocol completions endpoint.
|
||||
'''
|
||||
return [x.text for x in completions.choices]
|
||||
assert all([len(x.choices) == 1 for x in completions])
|
||||
return [x.choices[0].text for x in completions]
|
||||
|
||||
|
||||
def get_client_text_logprob_generations(
|
||||
completions: Completion) -> List[TextTextLogprobs]:
|
||||
completions: List[Completion]) -> List[TextTextLogprobs]:
|
||||
'''Operates on the output of a request made to an Open-AI-protocol
|
||||
completions endpoint; obtains top-rank logprobs for each token in
|
||||
each :class:`SequenceGroup`
|
||||
@ -526,4 +540,4 @@ def get_client_text_logprob_generations(
|
||||
text = ''.join(text_generations)
|
||||
return [(text_generations, text,
|
||||
(None if x.logprobs is None else x.logprobs.top_logprobs))
|
||||
for x in completions.choices]
|
||||
for completion in completions for x in completion.choices]
|
||||
|
||||
@ -97,6 +97,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
assert len(seqs) == 1, (
|
||||
"Beam search not supported in multi-step decoding.")
|
||||
seq = seqs[0]
|
||||
seq_id = seq.seq_id
|
||||
assert all(
|
||||
[seq_id == output.samples[0].parent_seq_id for output in outputs])
|
||||
|
||||
if is_async:
|
||||
# Async case: We process tokens one by one. Here, we know the token
|
||||
|
||||
@ -1007,8 +1007,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
|
||||
# Used to cache python objects
|
||||
self.inter_data_cache: Dict[int, PyObjectCache] = {}
|
||||
|
||||
# Using the PythonizationCache in Pipeline-Parallel clobbers the
|
||||
# SequenceGroupToSample object. In Pipeline-Parallel, we have
|
||||
# more than 1 Scheduler, resulting in a potential back-to-back
|
||||
# prepare_model_inputs() call. This clobbers the cached
|
||||
# SequenceGroupToSample objects, as we reset the cache during
|
||||
# every prepare_model_inputs() call.
|
||||
self.sampling_metadata_cache: SamplingMetadataCache = \
|
||||
SamplingMetadataCache()
|
||||
SamplingMetadataCache() \
|
||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||
|
||||
def load_model(self) -> None:
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
|
||||
@ -326,7 +326,14 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
self.is_multi_step = self.scheduler_config.is_multi_step
|
||||
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
|
||||
self.pythonization_cache = PythonizationCache()
|
||||
# Using the PythonizationCache in Pipeline-Parallel clobbers the
|
||||
# SequenceOutput and CompletionSequenceGroupOutput object.
|
||||
# When cache-reset happens at the last step of a multi-step
|
||||
# execution, there may be other on-going single-step/multi-step
|
||||
# executions. The current caching implementation does not check
|
||||
# for this.
|
||||
self.pythonization_cache = PythonizationCache() \
|
||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||
|
||||
@functools.cached_property
|
||||
def _copy_stream(self):
|
||||
@ -577,7 +584,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
if model_input.is_last_step:
|
||||
outputs = self._final_process_outputs(
|
||||
model_input, model_input.base_output_proc_callback)
|
||||
self.pythonization_cache.reset()
|
||||
if self.pythonization_cache:
|
||||
self.pythonization_cache.reset()
|
||||
return outputs
|
||||
|
||||
# should be [SamplerOutput]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user