[Bugfix] Fix PP for Multi-Step (#8887)
parent 39d3f8d94f
commit 19d02ff938
@@ -142,3 +142,85 @@ async def test_multi_step(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize(("tp_size, pp_size"), [
+    (1, 2),
+])
+@pytest.mark.asyncio
+async def test_multi_step_pp_smoke(
+    tp_size: int,
+    pp_size: int,
+    monkeypatch,
+) -> None:
+    """
+    Smoke test for the vLLM engine with multi-step scheduling in an
+    OpenAI-protocol client/server environment.
+
+    This test compares the outputs of multi-step scheduling and
+    single-step scheduling. Notably, this test lets the engines generate
+    more tokens than the default of 5 and tests for an exact match over
+    all the tokens.
+
+    Args:
+      tp_size: degree of tensor-parallelism
+      pp_size: degree of pipeline-parallelism
+      monkeypatch: pytest fixture used to override the attention backend
+    """
+
+    model = "JackFram/llama-160m"
+    num_scheduler_steps = 8
+    attention_backend = "FLASH_ATTN"
+    max_num_seqs = 3
+
+    override_backend_env_variable(monkeypatch, attention_backend)
+
+    # Prompt from the ShareGPT dataset
+    prompts = [
+        "in the jtbd context whats a push?",  # codespell:ignore
+        "in the jtbd context whats a push?",  # codespell:ignore
+        "in the jtbd context whats a push?",  # codespell:ignore
+        "in the jtbd context whats a push?",  # codespell:ignore
+    ]
+    # Use varying max_tokens to introduce scheduling randomness.
+    max_tokens = [10 * i for i in range(1, len(prompts) + 1)]
+    assert len(prompts) == len(max_tokens)
+
+    test_args = [
+        "--tensor-parallel-size",
+        str(tp_size), "--pipeline-parallel-size",
+        str(pp_size), "--max-num-seqs",
+        str(max_num_seqs)
+    ]
+
+    server_args = DEFAULT_SERVER_ARGS + test_args
+    ms_server_args = DEFAULT_SERVER_ARGS + \
+        ["--num-scheduler-steps", f"{num_scheduler_steps}"] + \
+        test_args
+
+    # Spin up client/server & issue completion API requests.
+    # Default `max_wait_seconds` is 240, but it is raised 5x
+    # *just for this test* due to observed timeouts in GHA CI.
+    ref_completions = await completions_with_server_args(
+        prompts=prompts,
+        model_name=model,
+        server_cli_args=server_args,
+        num_logprobs=None,
+        max_wait_seconds=5 * 240,
+        max_tokens=max_tokens)
+
+    test_completions = await completions_with_server_args(
+        prompts=prompts,
+        model_name=model,
+        server_cli_args=ms_server_args,
+        num_logprobs=None,
+        max_wait_seconds=5 * 240,
+        max_tokens=max_tokens)
+
+    # Assert multi-step scheduling produces identical tokens
+    # to single-step scheduling.
+    ref_generations = get_client_text_generations(ref_completions)
+    test_generations = get_client_text_generations(test_completions)
+
+    assert ref_generations == test_generations
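In the test above, the attention backend is pinned through an environment variable before the servers are launched. A minimal sketch of what a helper like override_backend_env_variable presumably boils down to; the VLLM_ATTENTION_BACKEND variable name and the helper body here are assumptions for illustration, not copied from the repository:

def override_backend_env_variable(monkeypatch, backend_name: str) -> None:
    # Sketch only: assumed env-var name. monkeypatch.setenv restores the
    # previous value automatically when the test finishes.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)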
@@ -1,3 +1,4 @@
+import asyncio
 import functools
 import os
 import signal
@@ -7,7 +8,7 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import openai
 import pytest
@@ -476,7 +477,8 @@ async def completions_with_server_args(
     server_cli_args: List[str],
     num_logprobs: Optional[int],
     max_wait_seconds: int = 240,
-) -> Completion:
+    max_tokens: Union[int, list] = 5,
+) -> List[Completion]:
     '''Construct a remote OpenAI server, obtain an async client to the
     server & invoke the completions API to obtain completions.
 
@@ -487,37 +489,49 @@ async def completions_with_server_args(
       num_logprobs: Number of logprobs to report (or `None`)
       max_wait_seconds: timeout interval for bringing up server.
                         Default: 240sec
+      max_tokens: max_tokens value for each of the given input prompts.
+        If only one max_tokens value is given, the same value is used
+        for all the prompts.
 
     Returns:
       OpenAI Completion instance
     '''
+
+    if isinstance(max_tokens, int):
+        max_tokens = [max_tokens] * len(prompts)
+
+    assert len(max_tokens) == len(prompts)
+
     outputs = None
     max_wait_seconds = 240 * 3  # 240 is default
     with RemoteOpenAIServer(model_name,
                             server_cli_args,
                             max_wait_seconds=max_wait_seconds) as server:
         client = server.get_async_client()
-        outputs = await client.completions.create(model=model_name,
-                                                   prompt=prompts,
-                                                   temperature=0,
-                                                   stream=False,
-                                                   max_tokens=5,
-                                                   logprobs=num_logprobs)
+        outputs = [client.completions.create(model=model_name,
+                                             prompt=[p],
+                                             temperature=0,
+                                             stream=False,
+                                             max_tokens=max_tok,
+                                             logprobs=num_logprobs)
+                   for p, max_tok in zip(prompts, max_tokens)]
+        outputs = await asyncio.gather(*outputs)
+
     assert outputs is not None, "Completion API call failed."
 
     return outputs
 
 
-def get_client_text_generations(completions: Completion) -> List[str]:
+def get_client_text_generations(completions: List[Completion]) -> List[str]:
     '''Extract generated tokens from the output of a
     request made to an Open-AI-protocol completions endpoint.
     '''
-    return [x.text for x in completions.choices]
+    assert all([len(x.choices) == 1 for x in completions])
+    return [x.choices[0].text for x in completions]
 
 
 def get_client_text_logprob_generations(
-        completions: Completion) -> List[TextTextLogprobs]:
+        completions: List[Completion]) -> List[TextTextLogprobs]:
     '''Operates on the output of a request made to an Open-AI-protocol
     completions endpoint; obtains top-rank logprobs for each token in
     each :class:`SequenceGroup`
@@ -526,4 +540,4 @@ def get_client_text_logprob_generations(
     text = ''.join(text_generations)
     return [(text_generations, text,
              (None if x.logprobs is None else x.logprobs.top_logprobs))
-            for x in completions.choices]
+            for completion in completions for x in completion.choices]
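The helper rewrite above swaps one batched completions call for one request per prompt, gathered concurrently so each prompt can carry its own max_tokens value. A self-contained sketch of that fan-out-and-broadcast pattern; fake_completion is a stand-in coroutine used only for illustration and is not part of the diff:

import asyncio
from typing import List, Union

async def fake_completion(prompt: str, max_tokens: int) -> str:
    # Stand-in for client.completions.create(); the real helper awaits the
    # OpenAI async client instead.
    await asyncio.sleep(0)
    return f"{prompt!r} -> up to {max_tokens} tokens"

async def fan_out(prompts: List[str],
                  max_tokens: Union[int, List[int]]) -> List[str]:
    # Same broadcast rule as completions_with_server_args: a single int
    # applies to every prompt.
    if isinstance(max_tokens, int):
        max_tokens = [max_tokens] * len(prompts)
    assert len(max_tokens) == len(prompts)
    # One request per prompt, issued concurrently; gather preserves order.
    coros = [fake_completion(p, m) for p, m in zip(prompts, max_tokens)]
    return await asyncio.gather(*coros)

print(asyncio.run(fan_out(["a", "b", "c"], 8)))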
@@ -97,6 +97,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         assert len(seqs) == 1, (
             "Beam search not supported in multi-step decoding.")
         seq = seqs[0]
+        seq_id = seq.seq_id
+        assert all(
+            [seq_id == output.samples[0].parent_seq_id for output in outputs])
 
         if is_async:
             # Async case: We process tokens one by one. Here, we know the token
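The new assertion encodes the invariant that, with beam search disabled, every buffered per-step output was sampled from the same single sequence of the group. A toy illustration of that check, using simplified stand-in types rather than vLLM's real sampler-output classes:

from dataclasses import dataclass
from typing import List

# Simplified stand-ins for vLLM's sampler-output types; illustration only.
@dataclass
class Sample:
    parent_seq_id: int

@dataclass
class StepOutput:
    samples: List[Sample]

def check_same_parent(seq_id: int, outputs: List[StepOutput]) -> None:
    # Mirrors the new assertion: every buffered step output must descend
    # from the same sequence.
    assert all(seq_id == out.samples[0].parent_seq_id for out in outputs)

check_same_parent(7, [StepOutput([Sample(7)]), StepOutput([Sample(7)])])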
@@ -1007,8 +1007,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         # Used to cache python objects
         self.inter_data_cache: Dict[int, PyObjectCache] = {}
+
+        # Using the SamplingMetadataCache in Pipeline-Parallel clobbers the
+        # SequenceGroupToSample objects. In Pipeline-Parallel, we have
+        # more than 1 Scheduler, resulting in a potential back-to-back
+        # prepare_model_inputs() call. This clobbers the cached
+        # SequenceGroupToSample objects, as we reset the cache during
+        # every prepare_model_inputs() call.
         self.sampling_metadata_cache: SamplingMetadataCache = \
-            SamplingMetadataCache()
+            SamplingMetadataCache() \
+            if self.parallel_config.pipeline_parallel_size == 1 else None
 
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
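The change above constructs the sampling-metadata cache only when a single scheduler drives input preparation; with pipeline parallelism, back-to-back prepare calls would reset the cache while objects from an earlier call are still in flight. A generic sketch of that guard pattern, with an illustrative ObjectCache stand-in rather than vLLM's SamplingMetadataCache:

from typing import Optional

class ObjectCache:
    """Illustrative stand-in for a reuse cache such as SamplingMetadataCache."""
    def __init__(self) -> None:
        self._pool: list = []

    def reset(self) -> None:
        # Recycles pooled objects; unsafe if another in-flight step still
        # holds references to them.
        self._pool.clear()

def make_cache(pipeline_parallel_size: int) -> Optional[ObjectCache]:
    # Reuse objects only when one scheduler owns the prepare loop; with
    # PP > 1, skip caching and allocate fresh objects each step.
    return ObjectCache() if pipeline_parallel_size == 1 else None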
@@ -326,7 +326,14 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         self.is_multi_step = self.scheduler_config.is_multi_step
         self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
 
-        self.pythonization_cache = PythonizationCache()
+        # Using the PythonizationCache in Pipeline-Parallel clobbers the
+        # SequenceOutput and CompletionSequenceGroupOutput objects.
+        # When a cache reset happens at the last step of a multi-step
+        # execution, there may be other on-going single-step/multi-step
+        # executions. The current caching implementation does not check
+        # for this.
+        self.pythonization_cache = PythonizationCache() \
+            if self.parallel_config.pipeline_parallel_size == 1 else None
 
     @functools.cached_property
     def _copy_stream(self):
@@ -577,7 +584,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         if model_input.is_last_step:
             outputs = self._final_process_outputs(
                 model_input, model_input.base_output_proc_callback)
-            self.pythonization_cache.reset()
+            if self.pythonization_cache:
+                self.pythonization_cache.reset()
             return outputs
 
         # should be [SamplerOutput]
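Because the pythonization cache may now be None under pipeline parallelism, its call sites become None-guarded, as in the reset above. A small sketch of the resulting control flow; the Runner class and ObjectCache stub are illustrative, not vLLM's actual classes:

from typing import Optional

class ObjectCache:
    """Minimal stub standing in for PythonizationCache."""
    def reset(self) -> None:
        pass

class Runner:
    def __init__(self, pipeline_parallel_size: int) -> None:
        # Construct the cache only in the single-scheduler (PP == 1) case.
        self.pythonization_cache: Optional[ObjectCache] = (
            ObjectCache() if pipeline_parallel_size == 1 else None)

    def finish_last_step(self) -> None:
        # Every use is guarded, so the PP > 1 path simply skips caching.
        if self.pythonization_cache:
            self.pythonization_cache.reset()

Runner(pipeline_parallel_size=2).finish_last_step()  # reset is a no-op under PP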