[Core] Optimize Async + Multi-step (#8050)

Alexander Matveev 2024-09-03 14:50:29 -04:00 committed by GitHub
parent 95a178f861
commit 6d646d08a2
8 changed files with 325 additions and 247 deletions


@ -103,13 +103,13 @@ async def test_multi_step(
model,
server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)
max_wait_seconds=5 * 240)
test_completions = await completions_with_server_args(
prompts,
model,
ms_server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)
max_wait_seconds=5 * 240)
# Assert multi-step scheduling produces identical tokens
# to single-step scheduling.
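The hunk above only raises the server-startup timeout; the surrounding test still asserts that multi-step scheduling emits exactly the same tokens as single-step scheduling. A minimal sketch of that equivalence check, assuming each completion result is reduced to its generated token ids (the helper name and result shape are illustrative, not the test suite's actual API):

```python
# Hypothetical comparison helper; `ref_completions` / `test_completions`
# stand in for the reduced results of completions_with_server_args().
from typing import List, Sequence


def assert_same_tokens(ref_completions: Sequence[List[int]],
                       test_completions: Sequence[List[int]]) -> None:
    """Fail if multi-step scheduling produced different tokens."""
    assert len(ref_completions) == len(test_completions)
    for i, (ref, test) in enumerate(zip(ref_completions, test_completions)):
        assert ref == test, (
            f"Prompt {i}: multi-step tokens {test} differ from "
            f"single-step tokens {ref}")
```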


@ -280,40 +280,27 @@ class _AsyncLLMEngine(LLMEngine):
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
ctx = self.scheduler_contexts[virtual_engine]
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))
self._process_model_outputs(ctx=ctx)
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
@ -351,26 +338,20 @@ class _AsyncLLMEngine(LLMEngine):
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
async_callback = self.async_callback_multi_step[
virtual_engine] if use_async_and_multi_step \
else self.async_callback[virtual_engine]
execute_model_req.async_callback = async_callback
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
# Execute the model.
output = await self.model_executor.execute_model_async(
execute_model_req)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
output = []
# Finish the current step for all the sequence groups.
@ -384,24 +365,22 @@ class _AsyncLLMEngine(LLMEngine):
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()
if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
is_async = allow_async_output_proc
is_last_step = True
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step))
if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if output and allow_async_output_proc:
assert len(
output
) == 1, "Async postprocessor expects only a single output set"
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if not allow_async_output_proc:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=False)
self._process_model_outputs(ctx=ctx)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
@ -411,17 +390,12 @@ class _AsyncLLMEngine(LLMEngine):
else:
# Multi-step case
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []
return ctx.request_outputs
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
return ctx.request_outputs
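When nothing was scheduled and no requests remain unfinished, any entries the async postprocessor left on ctx.output_queue are flushed through _process_model_outputs(ctx=ctx) before step_async returns, and the queue must be empty afterwards. A toy sketch of that drain shape, with stand-ins for SchedulerContext and the flush call:

```python
# Toy drain sketch; _Ctx and _flush stand in for SchedulerContext and
# _process_model_outputs(ctx=ctx). Only the queue/assert shape matters.
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List


@dataclass
class _Ctx:
    output_queue: Deque[Any] = field(default_factory=deque)
    request_outputs: List[Any] = field(default_factory=list)


def _flush(ctx: _Ctx) -> None:
    # The real code turns queued entries into RequestOutputs; here we pop.
    while ctx.output_queue:
        ctx.request_outputs.append(ctx.output_queue.popleft())


def drain_async_postprocessor(ctx: _Ctx) -> List[Any]:
    if len(ctx.output_queue) > 0:
        _flush(ctx)
    assert len(ctx.output_queue) == 0
    return ctx.request_outputs
```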
@ -640,6 +614,17 @@ class AsyncLLMEngine:
self.log_requests = log_requests
self.engine = self._init_engine(*args, **kwargs)
# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
# TODO: Currently, disabled for engine_use_ray, ask
# Cody/Will/Woosuk about this case.
self.use_process_request_outputs_callback = not self.engine_use_ray
if self.use_process_request_outputs_callback:
self.engine.process_request_outputs_callback = \
self.process_request_outputs
if self.engine_use_ray:
print_warning_once(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
@ -883,13 +868,27 @@ class AsyncLLMEngine:
request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams.
finished = True
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
if not self.use_process_request_outputs_callback:
all_finished = self.process_request_outputs(request_outputs)
else:
# For callback case, we only need to detect when all
# requests are finished
all_finished = all(request_output.finished
for request_output in request_outputs)
return not all_finished
def process_request_outputs(self, request_outputs) -> bool:
# Put the outputs into the corresponding streams.
all_finished = True
for request_output in request_outputs:
self._request_tracker.process_request_output(
request_output, verbose=self.log_requests)
finished = finished and request_output.finished
all_finished = all_finished and request_output.finished
return not finished
return all_finished
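The callback registered in the constructor lets the inner engine push request outputs into the per-request asyncio streams as soon as _process_model_outputs creates them, instead of waiting for step_async() to return; the engine-step loop then only has to detect when the whole batch is finished. A toy sketch of that wiring, with stand-ins for LLMEngine and AsyncLLMEngine (nothing beyond what the diff shows is assumed):

```python
# Toy sketch of the callback wiring. _InnerEngine/_OuterEngine are
# stand-ins; only the attribute names shown in the diff are reused.
from typing import Callable, List, Optional


class _InnerEngine:
    def __init__(self) -> None:
        self.process_request_outputs_callback: Optional[Callable] = None

    def _emit(self, request_outputs: List[object]) -> None:
        # Called from _process_model_outputs as soon as outputs exist.
        if self.process_request_outputs_callback is not None:
            self.process_request_outputs_callback(request_outputs)


class _OuterEngine:
    def __init__(self, use_callback: bool = True) -> None:
        self.engine = _InnerEngine()
        self.use_process_request_outputs_callback = use_callback
        if use_callback:
            self.engine.process_request_outputs_callback = \
                self.process_request_outputs

    def process_request_outputs(self, request_outputs: List[object]) -> bool:
        # Push each output to its stream (omitted here) and report whether
        # every request in the batch has finished.
        return all(getattr(o, "finished", False) for o in request_outputs)
```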
async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:


@ -93,13 +93,14 @@ class SchedulerOutputState:
@dataclass
class SchedulerContext:
output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
List[SequenceGroupMetadata],
SchedulerOutputs]] = field(
default_factory=lambda: deque())
List[SequenceGroupMetadata], SchedulerOutputs,
bool,
bool]] = field(default_factory=lambda: deque())
request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = field(
default_factory=lambda: [])
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
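Each output_queue entry is now a 5-tuple of (outputs, seq_group_metadata_list, scheduler_outputs, is_async, is_last_step), so a single _process_model_outputs(ctx=...) signature can serve the sync, async, and async+multi-step paths. An illustrative, simplified sketch of how producer and consumer agree on that shape (element types reduced to Any to stay self-contained):

```python
# Illustrative sketch of the new queue entry shape.
from collections import deque
from typing import Any, Deque, List, Optional, Tuple

QueueEntry = Tuple[Optional[List[Any]],  # sampler outputs (None until ready)
                   List[Any],            # seq_group_metadata_list
                   Any,                  # scheduler_outputs
                   bool,                 # is_async
                   bool]                 # is_last_step

output_queue: Deque[QueueEntry] = deque()

# Producer side (engine step): enqueue this step's results.
output_queue.append(([], [], object(), True, True))

# Consumer side (_process_model_outputs): unpack one pending entry.
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
 is_last_step) = output_queue.popleft()
```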
class LLMEngine:
@ -357,6 +358,26 @@ class LLMEngine:
# different process.
self.tokenizer.ping()
self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.scheduler_contexts = [
SchedulerContext()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.async_callbacks = [
functools.partial(self._process_model_outputs,
ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self.process_request_outputs_callback = None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
@ -364,9 +385,7 @@ class LLMEngine:
Scheduler(
scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size,
functools.partial(self._process_model_outputs,
virtual_engine=v_id,
is_async=True)
self.async_callbacks[v_id]
if model_config.use_async_output_proc else None)
for v_id in range(parallel_config.pipeline_parallel_size)
]
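Binding ctx= once per virtual engine means the scheduler and the model executor invoke the exact same callback object, and it is also how the multi-step runner later reaches the context through callback.keywords["ctx"]. A small sketch of that functools.partial pattern, with a placeholder context class so the example runs on its own:

```python
# Sketch of the per-virtual-engine callback binding; _Ctx replaces
# SchedulerContext to keep the example self-contained.
import functools


class _Ctx:
    pass


def _process_model_outputs(ctx: _Ctx) -> None:
    print("processing outputs for", ctx)


pipeline_parallel_size = 2
scheduler_contexts = [_Ctx() for _ in range(pipeline_parallel_size)]
async_callbacks = [
    functools.partial(_process_model_outputs, ctx=scheduler_contexts[v_id])
    for v_id in range(pipeline_parallel_size)
]

# The bound context stays reachable through the partial's keywords,
# which is what the multi-step model runner relies on.
assert async_callbacks[0].keywords["ctx"] is scheduler_contexts[0]
async_callbacks[0]()  # same as _process_model_outputs(ctx=scheduler_contexts[0])
```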
@ -417,30 +436,6 @@ class LLMEngine:
),
))
self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.scheduler_contexts = [
SchedulerContext()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.async_callback = [
functools.partial(self._process_model_outputs,
virtual_engine=v_id,
is_async=True)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
self.async_callback_multi_step = [
functools.partial(self._process_model_outputs,
virtual_engine=v_id,
is_async=False)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@ -1249,11 +1244,7 @@ class LLMEngine:
return
def _process_model_outputs(self,
virtual_engine: int,
is_async: bool,
sampler_output: Optional[SamplerOutput] = None,
is_last_output: bool = False) -> None:
def _process_model_outputs(self, ctx: SchedulerContext) -> None:
"""Apply the model output to the sequences in the scheduled seq groups.
virtual_engine: The engine id to operate on
@ -1273,24 +1264,12 @@ class LLMEngine:
"""
now = time.time()
is_multi_step = sampler_output is not None
ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]
if len(ctx.output_queue) == 0:
return None
if is_multi_step:
# Async + multi-step case
(outputs, seq_group_metadata_list,
scheduler_outputs) = ctx.output_queue[0]
assert outputs is None
outputs = [sampler_output]
else:
# Async standard case
(outputs, seq_group_metadata_list,
scheduler_outputs) = ctx.output_queue.popleft()
# Get pending async postprocessor
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step) = ctx.output_queue.popleft()
assert outputs is not None
# Sanity check
@ -1306,6 +1285,7 @@ class LLMEngine:
outputs_by_sequence_group = outputs
finished_before: List[int] = []
finished_now: List[int] = []
for i, seq_group_meta in enumerate(seq_group_metadata_list):
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
@ -1343,26 +1323,44 @@ class LLMEngine:
if self.model_config.embedding_mode:
self._process_sequence_group_outputs(seq_group, output)
continue
else:
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(
seq_group, output, is_async)
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(seq_group, output,
is_async)
if seq_group.is_finished():
finished_now.append(i)
# For async + multi-step, free finished seqs and create outputs
# only on the final step.
if is_multi_step and not is_last_output:
return
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# Create the outputs.
for i, _ in enumerate(seq_group_metadata_list):
# Generate outputs for the requests that finished this iteration
for i in finished_now:
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
if not is_multi_step and i in finished_before:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)
# Free currently finished requests
if finished_now:
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# For multi-step, do not create outputs each iteration
if not is_last_step:
# Immediately process request outputs here (if callback is given)
if (finished_now
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs)
return
# Create the outputs
# Note: scheduled_seq_groups and seq_group_metadata_list
# must match with the indices
for i, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
if i in finished_before or i in finished_now:
continue # Avoids double processing
seq_group = scheduled_seq_group.seq_group
@ -1376,11 +1374,15 @@ class LLMEngine:
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)
# For async + multi-step, do stats only on the last output.
# Otherwise, do stats if the execution is async
do_stats = is_multi_step or is_async
# Immediately process request outputs here (if callback is given)
if (ctx.request_outputs
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs)
if do_stats:
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
# LLMEngine/AsyncLLMEngine directly
if is_async:
# Log stats.
self.do_log_stats(scheduler_outputs, outputs, finished_before)
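The reworked _process_model_outputs above splits request-output creation into two passes: requests that finished this iteration are materialized (and optionally pushed through the callback) even on intermediate multi-step iterations, while the remaining requests are only materialized on the last step. A condensed, hedged outline of that control flow, not the real method:

```python
# Condensed outline of the reworked _process_model_outputs(); the
# per-sequence-group work is reduced to comments so only the two-pass
# output-creation structure remains visible.
from collections import deque
from typing import Callable, List, Optional


def process_outputs_outline(
        output_queue: deque,
        request_outputs: List[object],
        callback: Optional[Callable[[List[object]], None]] = None) -> None:
    if len(output_queue) == 0:
        return

    (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
     is_last_step) = output_queue.popleft()

    # Pass 1: apply sampler outputs (prompt logprobs, detokenize, stop
    # checks, ...) and record which sequence groups finished this iteration.
    finished_now: List[int] = []
    # ... per-sequence-group processing elided ...

    # Finished requests become RequestOutputs immediately and are freed.
    for i in finished_now:
        request_outputs.append(("request_output", i))  # placeholder object

    if not is_last_step:
        # Intermediate multi-step iteration: surface only what finished,
        # via the callback when one is registered, and defer the rest.
        if finished_now and callback is not None:
            callback(request_outputs)
        return

    # Pass 2 (last step only): create outputs for the still-running
    # requests, push the whole batch through the callback, and log stats
    # when the call came from the async path (is_async=True).
    # ... remaining output creation and stats logging elided ...
    if request_outputs and callback is not None:
        callback(request_outputs)
```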
@ -1485,40 +1487,26 @@ class LLMEngine:
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
ctx = self.scheduler_contexts[virtual_engine]
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))
self._process_model_outputs(ctx=ctx)
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
@ -1555,13 +1543,8 @@ class LLMEngine:
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
async_callback = self.async_callback_multi_step[
virtual_engine] if use_async_and_multi_step \
else self.async_callback[virtual_engine]
execute_model_req.async_callback = async_callback
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
@ -1573,10 +1556,8 @@ class LLMEngine:
else:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
output = []
@ -1590,28 +1571,24 @@ class LLMEngine:
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()
if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
# Add results to the output_queue
# (for async or non-async postprocessing)
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
# Add results to the output_queue
is_async = allow_async_output_proc
is_last_step = True
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step))
if output and allow_async_output_proc:
assert len(output) == 1, (
"Multi step decoding does not work "
"with async output processing.")
if output and allow_async_output_proc:
assert len(output) == 1, (
"Async postprocessor expects only a single output set")
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# Check if need to run the usual non-async path
if not allow_async_output_proc:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=False)
self._process_model_outputs(ctx=ctx)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
@ -1620,17 +1597,12 @@ class LLMEngine:
self.do_tracing(scheduler_outputs)
else:
# Multi-step case
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []
return ctx.request_outputs
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
# Stop the execute model loop in parallel workers until there are


@ -85,9 +85,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
# TODO: Add support for async if necessary
assert not is_async
# Sequences can be in RUNNING or FINISHED_ABORTED state
# once scheduled, as a sequence is moved to FINISHED_ABORTED
# if a client disconnects from the api server.
@ -101,19 +98,41 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Beam search not supported in multi-step decoding.")
seq = seqs[0]
# Since there's only one sequence per sequence group, we can take the
# first sample.
samples = [output.samples[0] for output in outputs]
if is_async:
# Async case: We process tokens one by one. Here, we know the token
# was already appended, so we only need to do the rest of the
# postprocessor: Detokenization + stopping logic
self._process_decode_and_stop(seq, sequence_group.sampling_params)
else:
# Standard multi-step case
# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
valid_samples = [
sample for sample in samples if sample.output_token != -1
]
assert valid_samples
# Since there's only one sequence per sequence group,
# we can take the first sample.
samples = [output.samples[0] for output in outputs]
self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
valid_samples = [
sample for sample in samples if sample.output_token != -1
]
assert valid_samples
self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
def _process_decode_and_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
new_char_count = 0
if sampling_params.detokenize:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
# TODO(sang): Support lora.
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count=new_char_count,
sampling_params=sampling_params,
)
def _process_seq_outputs(self, seq: Sequence,
valid_samples: List[SequenceOutput],
@ -151,16 +170,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
logprobs=output_logprob,
)
new_char_count = 0
if sampling_params.detokenize:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
self._process_decode_and_stop(seq, sampling_params)
# TODO(sang): Support lora.
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count=new_char_count,
sampling_params=sampling_params,
)
if seq.is_finished():
break
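In the async path the sampled token was already appended by _advance_to_next_step, so the output processor only needs detokenization plus the stop check, which is what the new _process_decode_and_stop helper factors out and what _process_seq_outputs now calls per token. A hedged, stand-alone sketch of that split, with minimal stand-ins for the detokenizer and stop checker:

```python
# Stand-alone sketch of the detokenize-then-stop-check split; _Params and
# _Seq are minimal stand-ins, not vLLM's SamplingParams/Sequence classes.
from dataclasses import dataclass


@dataclass
class _Params:
    detokenize: bool = True
    max_new_chars: int = 80  # illustrative stop condition


@dataclass
class _Seq:
    text: str = ""
    finished: bool = False


def process_decode_and_stop(seq: _Seq, params: _Params,
                            new_token_text: str) -> None:
    new_char_count = 0
    if params.detokenize:
        # Real code: detokenizer.decode_sequence_inplace(seq, params)
        seq.text += new_token_text
        new_char_count = len(new_token_text)
    # Real code: stop_checker.maybe_stop_sequence(seq, new_char_count, ...)
    if new_char_count and len(seq.text) >= params.max_new_chars:
        seq.finished = True
```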


@ -1225,7 +1225,6 @@ class ExecuteModelRequest(
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async callback
async_callback: Optional[Callable] = None
use_async_and_multi_step: bool = False
@property
def is_first_multi_step(self) -> bool:
@ -1272,5 +1271,4 @@ class ExecuteModelRequest(
finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None,
async_callback=self.async_callback,
use_async_and_multi_step=self.use_async_and_multi_step)
async_callback=self.async_callback)


@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY, InputRegistry
@ -96,7 +97,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
async_callback: Optional[Callable] = None
use_async_and_multi_step: bool = False
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {


@ -22,6 +22,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
get_pythonized_sample_results)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import PyObjectCache
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
@ -37,6 +38,29 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
def seq_output_builder():
return SequenceOutput(
0, 0,
{0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)})
def completion_seq_group_output_builder():
return CompletionSequenceGroupOutput([], None)
# Used by pythonization to reduce python object allocations
class PythonizationCache:
def __init__(self):
self.cached_seq_output = PyObjectCache(seq_output_builder)
self.cached_completion_seq_group_output = PyObjectCache(
completion_seq_group_output_builder)
def reset(self):
self.cached_seq_output.reset()
self.cached_completion_seq_group_output.reset()
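PythonizationCache wraps two PyObjectCache pools so per-token SequenceOutput and CompletionSequenceGroupOutput objects are reused across steps instead of reallocated. The diff only shows PyObjectCache's get_object()/reset() surface, so the sketch below uses a generic stand-in pool with the same surface to illustrate the object-reuse idea:

```python
# Generic object-pool sketch mirroring the get_object()/reset() surface the
# diff relies on; this is NOT vllm.utils.PyObjectCache itself.
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")


class _ObjectPool(Generic[T]):
    def __init__(self, builder: Callable[[], T]) -> None:
        self._builder = builder
        self._objects: List[T] = []
        self._index = 0

    def get_object(self) -> T:
        # Hand out a previously built object if one is free, otherwise
        # build (and remember) a new one.
        if self._index == len(self._objects):
            self._objects.append(self._builder())
        obj = self._objects[self._index]
        self._index += 1
        return obj

    def reset(self) -> None:
        # Make every pooled object available again for the next step.
        self._index = 0


pool = _ObjectPool(list)   # the builder creates the pooled objects
a = pool.get_object()
pool.reset()
b = pool.get_object()
assert a is b              # the same object is reused after reset()
```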
@dataclass
class ModelOutput:
"""The output of a single model forward pass.
@ -59,6 +83,7 @@ class ModelOutput:
pythonized: bool = False
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
pythonization_cache: Optional[PythonizationCache] = None
def pythonize(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
@ -97,7 +122,8 @@ class ModelOutput:
with torch.cuda.stream(copy_stream):
_pythonize_sampler_output(input_metadata, self.sampler_output,
pinned_sampled_token_buffer,
self.sampled_token_ids, self.logprobs)
self.sampled_token_ids, self.logprobs,
self.pythonization_cache)
# Erase the logprobs GPU-side tensor.
# Note that although _pythonize_sampler_output() runs in its
@ -209,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self._copy_stream = torch.cuda.Stream()
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
self.pythonization_cache = PythonizationCache()
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
@ -237,14 +265,22 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
output_proc_callback: Callable):
# Proceed with pythonization and output_proc in order.
# Stop on the first one that fails to pythonize
output_proc_callback()
cont = True
for model_output in model_input.cached_outputs:
if not model_output.pythonized:
model_output.maybe_pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
if model_output.pythonized:
output_proc_callback(
sampler_output=model_output.sampler_output)
ctx = output_proc_callback.keywords["ctx"]
is_async = False
is_last_step = False
ctx.output_queue.append(
([model_output.sampler_output
], ctx.seq_group_metadata_list,
ctx.scheduler_outputs, is_async, is_last_step))
output_proc_callback()
else:
cont = False
@ -255,21 +291,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
output_proc_callback: Optional[Callable]):
assert model_input.frozen_model_input is not None
has_async_callback = output_proc_callback is not None
outputs = []
for output_id in range(len(model_input.cached_outputs)):
is_last_output = output_id == len(model_input.cached_outputs) - 1
output = model_input.cached_outputs[output_id]
if not output.pythonized:
is_last_step = output_id == len(model_input.cached_outputs) - 1
# For non-async case:
# -- We simply add the outputs
# For async case:
# -- Invoke callback, pythonize, add to callback queue and repeat
# -- For last output, just add to callback queue
if has_async_callback:
assert output_proc_callback is not None
# Invoke callback before pythonize (to overlap with GPU)
output_proc_callback()
# Pythonize
if not output.pythonized:
output.pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
# For non last step, add to callback queue to chain
# callbacks=>pythonize pairs (for GPU overlap)
if not is_last_step:
ctx = output_proc_callback.keywords[ # type: ignore
"ctx"] # type: ignore
is_async = False
is_last_step = False
ctx.output_queue.append(
([output.sampler_output
], ctx.seq_group_metadata_list,
ctx.scheduler_outputs, is_async, is_last_step))
else:
outputs.append(output.sampler_output)
else:
output.pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
if model_input.frozen_model_input.use_async_and_multi_step:
assert output_proc_callback is not None
output_proc_callback(sampler_output=output.sampler_output,
is_last_output=is_last_output)
outputs.append(output.sampler_output)
outputs.append(output.sampler_output)
return outputs
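The loop above interleaves callback invocations with pythonization so CPU-side postprocessing overlaps the GPU steps: each non-final step is re-enqueued on ctx.output_queue with is_async=False and is_last_step=False, the bound callback is invoked before the next pythonize, and only the final step is returned as the regular output. A compressed, hedged sketch of that chaining, with toy stand-ins for the callback and outputs:

```python
# Toy sketch of the callback=>pythonize chaining; `callback` stands in for
# the functools.partial(_process_model_outputs, ctx=ctx) bound earlier.
import functools
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Tuple


@dataclass
class _Ctx:
    output_queue: Deque[Tuple[Any, ...]] = field(default_factory=deque)
    seq_group_metadata_list: Any = None
    scheduler_outputs: Any = None


def _process_model_outputs(ctx: _Ctx) -> None:
    if ctx.output_queue:
        ctx.output_queue.popleft()  # placeholder for the real postprocessing


def final_process_outputs(cached_outputs: List[Any], callback) -> List[Any]:
    ctx: _Ctx = callback.keywords["ctx"]
    outputs: List[Any] = []
    for step_id, sampler_output in enumerate(cached_outputs):
        is_last_step = step_id == len(cached_outputs) - 1
        # Invoke the callback before pythonizing to overlap with the GPU.
        callback()
        if not is_last_step:
            # Chain intermediate steps through the queue with
            # is_async=False, is_last_step=False.
            ctx.output_queue.append(
                ([sampler_output], ctx.seq_group_metadata_list,
                 ctx.scheduler_outputs, False, False))
        else:
            outputs.append(sampler_output)
    return outputs


ctx = _Ctx()
cb = functools.partial(_process_model_outputs, ctx=ctx)
print(final_process_outputs(["step0", "step1", "step2"], cb))  # ['step2']
```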
@ -330,7 +391,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input, model_input.cached_outputs[-1].sampler_output)
output_proc_callback = None
if frozen_model_input.use_async_and_multi_step:
if frozen_model_input.async_callback is not None:
output_proc_callback = frozen_model_input.async_callback
assert output_proc_callback is not None
async_callback = functools.partial(
@ -367,7 +428,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input.cached_outputs.append(
ModelOutput(output[0], output_ready_event,
output[0].sampled_token_ids, False,
output[0].logprobs))
output[0].logprobs, self.pythonization_cache))
# These GPU tensors are not required by multi-step;
# erase them to ensure they are not pythonized or
@ -378,7 +439,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Pythonize the output if CPU is ahead and the previous step is
# ready.
if not frozen_model_input.use_async_and_multi_step:
if frozen_model_input.async_callback is None:
for model_output in model_input.cached_outputs:
model_output.maybe_pythonize(model_input,
self._copy_stream,
@ -397,6 +458,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
if model_input.is_last_step:
outputs = self._final_process_outputs(model_input,
output_proc_callback)
self.pythonization_cache.reset()
return outputs
# should be [SamplerOutput]
@ -537,6 +599,7 @@ def _pythonize_sampler_output(
pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor,
logprobs_tensor: Optional[torch.Tensor],
cache: Optional[PythonizationCache],
) -> None:
""" This function is only called when the output tensors are ready.
See :class:`ModelOutput`.
@ -597,6 +660,9 @@ def _pythonize_sampler_output(
for sgdx, (seq_group,
sample_result) in enumerate(zip(seq_groups, samples_list)):
if seq_group.sampling_params.logits_processors:
assert len(seq_group.sampling_params.logits_processors) == 0, (
"Logits Processors are not supported in multi-step decoding")
if do_pythonize_logprobs:
assert prompt_logprobs is not None
@ -621,23 +687,56 @@ def _pythonize_sampler_output(
seq_ids = seq_group.seq_ids
next_token_ids = sample_result
parent_ids = [0]
seq_outputs: List[SequenceOutput] = []
if seq_group.sampling_params.logits_processors:
assert len(seq_group.sampling_params.logits_processors) == 0, (
"Logits Processors are not supported in multi-step decoding")
if cache is not None:
completion_seq_group_output: CompletionSequenceGroupOutput = \
cache.cached_completion_seq_group_output.get_object()
completion_seq_group_output.samples.clear()
seq_outputs: List[
SequenceOutput] = completion_seq_group_output.samples
else:
seq_outputs = []
for tdx, (parent_id,
next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
(group_sample_logprobs[tdx]
if logprobs_are_requested else {
next_token_id:
Logprob(logprob=float('inf'),
rank=None,
decoded_token=None)
})))
output.outputs.append(
CompletionSequenceGroupOutput(
seq_outputs,
(group_prompt_logprobs if logprobs_are_requested else None)))
if cache is not None:
seq_output: SequenceOutput = cache.cached_seq_output.get_object(
)
seq_output.parent_seq_id = seq_ids[parent_id]
seq_output.output_token = next_token_id
if logprobs_are_requested:
seq_output.logprobs = group_sample_logprobs[tdx]
else:
logprobs = next(iter(seq_output.logprobs.values()))
seq_output.logprobs.clear()
logprobs.logprob = float('inf')
logprobs.rank = None
logprobs.decoded_token = None
seq_output.logprobs[next_token_id] = logprobs
seq_outputs.append(seq_output)
else:
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
(group_sample_logprobs[tdx]
if logprobs_are_requested else {
next_token_id:
Logprob(logprob=float('inf'),
rank=None,
decoded_token=None)
})))
if cache is not None:
completion_seq_group_output.prompt_logprobs = \
group_prompt_logprobs if logprobs_are_requested else None
output.outputs.append(completion_seq_group_output)
else:
output.outputs.append(
CompletionSequenceGroupOutput(
seq_outputs, (group_prompt_logprobs
if logprobs_are_requested else None)))
assert len(output.outputs) > 0
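When a PythonizationCache is supplied, _pythonize_sampler_output rewrites the fields of pooled SequenceOutput and CompletionSequenceGroupOutput objects in place rather than constructing fresh ones per token. A hedged sketch of that reuse-versus-allocate branch with a simplified output type:

```python
# Simplified sketch of the reuse-or-allocate branch; _SeqOut is a toy
# stand-in for SequenceOutput and `pooled` for an object handed out by
# PythonizationCache.cached_seq_output.
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class _SeqOut:
    parent_seq_id: int = 0
    output_token: int = 0
    logprobs: Dict[int, float] = field(
        default_factory=lambda: {0: float("inf")})


def append_seq_output(seq_outputs: List[_SeqOut], parent_seq_id: int,
                      next_token_id: int, logprob: float,
                      pooled: Optional[_SeqOut]) -> None:
    if pooled is not None:
        # Cache hit: overwrite the pooled object's fields in place.
        pooled.parent_seq_id = parent_seq_id
        pooled.output_token = next_token_id
        pooled.logprobs.clear()
        pooled.logprobs[next_token_id] = logprob
        seq_outputs.append(pooled)
    else:
        # No cache: allocate a new object, as the pre-existing path did.
        seq_outputs.append(
            _SeqOut(parent_seq_id, next_token_id, {next_token_id: logprob}))
```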


@ -67,9 +67,7 @@ class MultiStepWorker(Worker):
if execute_model_req.async_callback:
model_input.frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=execute_model_req.async_callback,
use_async_and_multi_step=execute_model_req.
use_async_and_multi_step)
async_callback=execute_model_req.async_callback)
else:
# on subsequent steps we reuse the worker input and model input
multi_step_state = self.multi_step_states[virtual_engine]