[Core] Optimize Async + Multi-step (#8050)

Alexander Matveev 2024-09-03 14:50:29 -04:00 committed by GitHub
parent 95a178f861
commit 6d646d08a2
8 changed files with 325 additions and 247 deletions

View File

@@ -103,13 +103,13 @@ async def test_multi_step(
         model,
         server_args + distributed_args,
         num_logprobs,
-        max_wait_seconds=3 * 240)
+        max_wait_seconds=5 * 240)
     test_completions = await completions_with_server_args(
         prompts,
         model,
         ms_server_args + distributed_args,
         num_logprobs,
-        max_wait_seconds=3 * 240)
+        max_wait_seconds=5 * 240)

     # Assert multi-step scheduling produces identical tokens
     # to single-step scheduling.

View File

@@ -280,40 +280,27 @@ class _AsyncLLMEngine(LLMEngine):
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

-        # Detect async + multi-step
-        use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                    and allow_async_output_proc)
-
         ctx = self.scheduler_contexts[virtual_engine]

+        # Clear outputs for each new scheduler iteration
+        ctx.request_outputs.clear()
+
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):

-            # Clear outputs on scheduler iteration start
-            ctx.request_outputs.clear()
-
             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()

-            # Detect async + multi-step
-            use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                        and allow_async_output_proc)
+            ctx.seq_group_metadata_list = seq_group_metadata_list
+            ctx.scheduler_outputs = scheduler_outputs

             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
-
-            # For async + multi-step, init the queue
-            if use_async_and_multi_step:
-                assert len(ctx.output_queue) == 0
-                assert seq_group_metadata_list is not None
-                ctx.output_queue.append(
-                    (None, seq_group_metadata_list, scheduler_outputs))
+                self._process_model_outputs(ctx=ctx)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
@@ -351,26 +338,20 @@ class _AsyncLLMEngine(LLMEngine):
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                async_callback = self.async_callback_multi_step[
-                    virtual_engine] if use_async_and_multi_step \
-                    else self.async_callback[virtual_engine]
-
-                execute_model_req.async_callback = async_callback
-                execute_model_req.use_async_and_multi_step = \
-                    use_async_and_multi_step
+                execute_model_req.async_callback = self.async_callbacks[
+                    virtual_engine]

             # Execute the model.
             output = await self.model_executor.execute_model_async(
                 execute_model_req)

             # we need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+            if len(ctx.output_queue) > 0:
+                self._process_model_outputs(ctx=ctx)
             output = []

         # Finish the current step for all the sequence groups.
@@ -384,24 +365,22 @@ class _AsyncLLMEngine(LLMEngine):
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()

-            if use_async_and_multi_step:
-                # For async + multi-step, clear the queue
-                ctx.output_queue.clear()
-            else:
-                ctx.output_queue.append(
-                    (output, seq_group_metadata_list, scheduler_outputs))
+            is_async = allow_async_output_proc
+            is_last_step = True
+            ctx.output_queue.append(
+                (output, seq_group_metadata_list, scheduler_outputs, is_async,
+                 is_last_step))

             if output and allow_async_output_proc:
                 assert len(
                     output
-                ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
+                ) == 1, "Async postprocessor expects only a single output set"
                 self._advance_to_next_step(
                     output[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)

             if not allow_async_output_proc:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=False)
+                self._process_model_outputs(ctx=ctx)

                 # Log stats.
                 self.do_log_stats(scheduler_outputs, output)
@@ -411,17 +390,12 @@ class _AsyncLLMEngine(LLMEngine):
         else:
             # Multi-step case
-            if use_async_and_multi_step:
-                return []
-            else:
-                ctx.request_outputs = []
+            return ctx.request_outputs

         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)
             if len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+                self._process_model_outputs(ctx=ctx)
             assert len(ctx.output_queue) == 0

         return ctx.request_outputs
@@ -640,6 +614,17 @@ class AsyncLLMEngine:
         self.log_requests = log_requests
         self.engine = self._init_engine(*args, **kwargs)

+        # This ensures quick processing of request outputs
+        # so the append to asyncio queues is not delayed,
+        # especially for multi-step.
+        #
+        # TODO: Currently, disabled for engine_use_ray, ask
+        # Cody/Will/Woosuk about this case.
+        self.use_process_request_outputs_callback = not self.engine_use_ray
+        if self.use_process_request_outputs_callback:
+            self.engine.process_request_outputs_callback = \
+                self.process_request_outputs
+
         if self.engine_use_ray:
             print_warning_once(
                 "DEPRECATED. `--engine-use-ray` is deprecated and will "
@@ -883,13 +868,27 @@ class AsyncLLMEngine:
             request_outputs = await self.engine.step_async(virtual_engine)

         # Put the outputs into the corresponding streams.
-        finished = True
+        # If used as a callback, then already invoked inside
+        # LLMEngine's _process_model_outputs
+        if not self.use_process_request_outputs_callback:
+            all_finished = self.process_request_outputs(request_outputs)
+        else:
+            # For callback case, we only need to detect when all
+            # requests are finished
+            all_finished = all(request_output.finished
+                               for request_output in request_outputs)
+
+        return not all_finished
+
+    def process_request_outputs(self, request_outputs) -> bool:
+        # Put the outputs into the corresponding streams.
+        all_finished = True
         for request_output in request_outputs:
             self._request_tracker.process_request_output(
                 request_output, verbose=self.log_requests)
-            finished = finished and request_output.finished
+            all_finished = all_finished and request_output.finished

-        return not finished
+        return all_finished

     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:

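The changes above replace the separate single-step/multi-step callbacks with one `async_callbacks` list and let `AsyncLLMEngine` register `process_request_outputs` on the engine, so finished outputs are pushed into the asyncio streams from inside output processing instead of after `step_async()` returns. Below is a minimal, self-contained sketch of that callback wiring; the `Engine`, `AsyncFrontend`, and simplified `RequestOutput` names are illustrative stand-ins, not vLLM classes.

from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class RequestOutput:  # stand-in for vllm.outputs.RequestOutput
    request_id: str
    finished: bool


class Engine:
    """Produces request outputs; hands them to a registered callback, if any."""

    def __init__(self) -> None:
        self.process_request_outputs_callback: Optional[
            Callable[[List[RequestOutput]], bool]] = None

    def step(self) -> List[RequestOutput]:
        outputs = [RequestOutput("req-0", finished=True)]
        if self.process_request_outputs_callback is not None:
            # Push outputs to the frontend as soon as they exist,
            # instead of waiting for step() to return.
            self.process_request_outputs_callback(outputs)
        return outputs


class AsyncFrontend:
    """Registers itself as the engine's output callback."""

    def __init__(self, engine: Engine) -> None:
        self.engine = engine
        engine.process_request_outputs_callback = self.process_request_outputs

    def process_request_outputs(self, outputs: List[RequestOutput]) -> bool:
        for out in outputs:
            print(f"stream {out.request_id}: finished={out.finished}")
        return all(out.finished for out in outputs)


frontend = AsyncFrontend(Engine())
frontend.engine.step()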
View File

@@ -93,13 +93,14 @@ class SchedulerOutputState:
 @dataclass
 class SchedulerContext:
     output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
-                              List[SequenceGroupMetadata],
-                              SchedulerOutputs]] = field(
-                                  default_factory=lambda: deque())
+                              List[SequenceGroupMetadata], SchedulerOutputs,
+                              bool,
+                              bool]] = field(default_factory=lambda: deque())

     request_outputs: List[Union[RequestOutput,
                                 EmbeddingRequestOutput]] = field(
                                     default_factory=lambda: [])
+    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
+    scheduler_outputs: Optional[SchedulerOutputs] = None


 class LLMEngine:
@@ -357,6 +358,26 @@ class LLMEngine:
             # different process.
             self.tokenizer.ping()

+        self.cached_scheduler_outputs = [
+            SchedulerOutputState()
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
+        self.scheduler_contexts = [
+            SchedulerContext()
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
+        self.async_callbacks = [
+            functools.partial(self._process_model_outputs,
+                              ctx=self.scheduler_contexts[v_id])
+            for v_id in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
+        # Currently used by AsyncLLMEngine to ensure quick append
+        # of request outputs to asyncio queues
+        self.process_request_outputs_callback = None
+
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of
         # GPU and CPU blocks, which are profiled in the distributed executor.
@@ -364,9 +385,7 @@ class LLMEngine:
             Scheduler(
                 scheduler_config, cache_config, lora_config,
                 parallel_config.pipeline_parallel_size,
-                functools.partial(self._process_model_outputs,
-                                  virtual_engine=v_id,
-                                  is_async=True)
+                self.async_callbacks[v_id]
                 if model_config.use_async_output_proc else None)
             for v_id in range(parallel_config.pipeline_parallel_size)
         ]
@@ -417,30 +436,6 @@ class LLMEngine:
                 ),
             ))

-        self.cached_scheduler_outputs = [
-            SchedulerOutputState()
-            for _ in range(self.parallel_config.pipeline_parallel_size)
-        ]
-
-        self.scheduler_contexts = [
-            SchedulerContext()
-            for _ in range(self.parallel_config.pipeline_parallel_size)
-        ]
-
-        self.async_callback = [
-            functools.partial(self._process_model_outputs,
-                              virtual_engine=v_id,
-                              is_async=True)
-            for v_id in range(self.parallel_config.pipeline_parallel_size)
-        ]
-
-        self.async_callback_multi_step = [
-            functools.partial(self._process_model_outputs,
-                              virtual_engine=v_id,
-                              is_async=False)
-            for v_id in range(self.parallel_config.pipeline_parallel_size)
-        ]
-
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
@@ -1249,11 +1244,7 @@ class LLMEngine:
             return

-    def _process_model_outputs(self,
-                               virtual_engine: int,
-                               is_async: bool,
-                               sampler_output: Optional[SamplerOutput] = None,
-                               is_last_output: bool = False) -> None:
+    def _process_model_outputs(self, ctx: SchedulerContext) -> None:
         """Apply the model output to the sequences in the scheduled seq groups.

         virtual_engine: The engine id to operate on
@@ -1273,24 +1264,12 @@ class LLMEngine:
         """
         now = time.time()

-        is_multi_step = sampler_output is not None
-
-        ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]
-
         if len(ctx.output_queue) == 0:
             return None

-        if is_multi_step:
-            # Async + multi-step case
-            (outputs, seq_group_metadata_list,
-             scheduler_outputs) = ctx.output_queue[0]
-            assert outputs is None
-            outputs = [sampler_output]
-        else:
-            # Async standard case
-            (outputs, seq_group_metadata_list,
-             scheduler_outputs) = ctx.output_queue.popleft()
-
+        # Get pending async postprocessor
+        (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
+         is_last_step) = ctx.output_queue.popleft()
         assert outputs is not None

         # Sanity check
@@ -1306,6 +1285,7 @@ class LLMEngine:
             outputs_by_sequence_group = outputs

         finished_before: List[int] = []
+        finished_now: List[int] = []
         for i, seq_group_meta in enumerate(seq_group_metadata_list):
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
@@ -1343,26 +1323,44 @@ class LLMEngine:
             if self.model_config.embedding_mode:
                 self._process_sequence_group_outputs(seq_group, output)
-                continue
-
-            self.output_processor.process_prompt_logprob(seq_group, output)
-            if seq_group_meta.do_sample:
-                self.output_processor.process_outputs(seq_group, output,
-                                                      is_async)
-
-        # For async + multi-step, free finished seqs and create outputs
-        # only on the final step.
-        if is_multi_step and not is_last_output:
-            return
-
-        for scheduler in self.scheduler:
-            scheduler.free_finished_seq_groups()
-
-        # Create the outputs.
-        for i, _ in enumerate(seq_group_metadata_list):
-            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
-
-            if not is_multi_step and i in finished_before:
+            else:
+                self.output_processor.process_prompt_logprob(seq_group, output)
+                if seq_group_meta.do_sample:
+                    self.output_processor.process_outputs(
+                        seq_group, output, is_async)
+
+            if seq_group.is_finished():
+                finished_now.append(i)
+
+        # Generate outputs for the requests that finished this iteration
+        for i in finished_now:
+            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
+            seq_group = scheduled_seq_group.seq_group
+            seq_group.maybe_set_first_token_time(now)
+            request_output = RequestOutputFactory.create(seq_group)
+            ctx.request_outputs.append(request_output)
+
+        # Free currently finished requests
+        if finished_now:
+            for scheduler in self.scheduler:
+                scheduler.free_finished_seq_groups()
+
+        # For multi-step, do not create outputs each iteration
+        if not is_last_step:
+            # Immediately process request outputs here (if callback is given)
+            if (finished_now
+                    and self.process_request_outputs_callback is not None):
+                self.process_request_outputs_callback(ctx.request_outputs)
+            return
+
+        # Create the outputs
+        # Note: scheduled_seq_groups and seq_group_metadata_list
+        # must match with the indices
+        for i, scheduled_seq_group in enumerate(
+                scheduler_outputs.scheduled_seq_groups):
+            if i in finished_before or i in finished_now:
                 continue  # Avoids double processing

             seq_group = scheduled_seq_group.seq_group
@@ -1376,11 +1374,15 @@ class LLMEngine:
             request_output = RequestOutputFactory.create(seq_group)
             ctx.request_outputs.append(request_output)

-        # For async + multi-step, do stats only on the last output.
-        # Otherwise, do stats if the execution is async
-        do_stats = is_multi_step or is_async
-
-        if do_stats:
+        # Immediately process request outputs here (if callback is given)
+        if (ctx.request_outputs
+                and self.process_request_outputs_callback is not None):
+            self.process_request_outputs_callback(ctx.request_outputs)
+
+        # For async case, we need to record the stats here.
+        # For non-async case, the stats are done in the
+        # LLMEngine/AsyncLLMEngine directly
+        if is_async:
             # Log stats.
             self.do_log_stats(scheduler_outputs, outputs, finished_before)
@@ -1485,40 +1487,26 @@ class LLMEngine:
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

-        # Detect async + multi-step
-        use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                    and allow_async_output_proc)
-
         ctx = self.scheduler_contexts[virtual_engine]

+        # Clear outputs for each new scheduler iteration
+        ctx.request_outputs.clear()
+
         # Skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):

-            # Clear outputs on scheduler iteration start
-            ctx.request_outputs.clear()
-
             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()

-            # Detect async + multi-step
-            use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                        and allow_async_output_proc)
+            ctx.seq_group_metadata_list = seq_group_metadata_list
+            ctx.scheduler_outputs = scheduler_outputs

             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
-
-            # For async + multi-step, init the queue
-            if use_async_and_multi_step:
-                assert len(ctx.output_queue) == 0
-                assert seq_group_metadata_list is not None
-                ctx.output_queue.append(
-                    (None, seq_group_metadata_list, scheduler_outputs))
+                self._process_model_outputs(ctx=ctx)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
@@ -1555,13 +1543,8 @@ class LLMEngine:
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                async_callback = self.async_callback_multi_step[
-                    virtual_engine] if use_async_and_multi_step \
-                    else self.async_callback[virtual_engine]
-
-                execute_model_req.async_callback = async_callback
-                execute_model_req.use_async_and_multi_step = \
-                    use_async_and_multi_step
+                execute_model_req.async_callback = self.async_callbacks[
+                    virtual_engine]

             output = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)
@@ -1573,10 +1556,8 @@ class LLMEngine:
         else:
             # Nothing scheduled => If there is pending async postprocessor,
             # then finish it here.
-            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+            if len(ctx.output_queue) > 0:
+                self._process_model_outputs(ctx=ctx)
             # No outputs in this case
             output = []
@@ -1590,28 +1571,24 @@ class LLMEngine:
             if self.scheduler_config.is_multi_step:
                 self.cached_scheduler_outputs[0] = SchedulerOutputState()

-            if use_async_and_multi_step:
-                # For async + multi-step, clear the queue
-                ctx.output_queue.clear()
-            else:
-                # Add results to the output_queue
-                # (for async or non-async postprocessing)
-                ctx.output_queue.append(
-                    (output, seq_group_metadata_list, scheduler_outputs))
+            # Add results to the output_queue
+            is_async = allow_async_output_proc
+            is_last_step = True
+            ctx.output_queue.append(
+                (output, seq_group_metadata_list, scheduler_outputs, is_async,
+                 is_last_step))

             if output and allow_async_output_proc:
                 assert len(output) == 1, (
-                    "Multi step decoding does not work "
-                    "with async output processing.")
+                    "Async postprocessor expects only a single output set")

                 self._advance_to_next_step(
                     output[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)

             # Check if need to run the usual non-async path
             if not allow_async_output_proc:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=False)
+                self._process_model_outputs(ctx=ctx)

                 # Log stats.
                 self.do_log_stats(scheduler_outputs, output)
@@ -1620,17 +1597,12 @@ class LLMEngine:
                 self.do_tracing(scheduler_outputs)
         else:
             # Multi-step case
-            if use_async_and_multi_step:
-                return []
-            else:
-                ctx.request_outputs = []
+            return ctx.request_outputs

         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)
             if len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+                self._process_model_outputs(ctx=ctx)
             assert len(ctx.output_queue) == 0

             # Stop the execute model loop in parallel workers until there are

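With this change every virtual engine owns a SchedulerContext, and each output_queue entry carries (outputs, seq_group_metadata_list, scheduler_outputs, is_async, is_last_step); intermediate multi-step entries are drained without creating user-visible outputs until the last step. The sketch below only mirrors that queue shape with placeholder types to show the drain logic; it is an illustration, not the vLLM implementation.

from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Optional, Tuple


@dataclass
class MiniSchedulerContext:
    # Entries mirror: (outputs, seq_group_metadata_list, scheduler_outputs,
    #                  is_async, is_last_step)
    output_queue: Deque[Tuple[Optional[List[Any]], List[Any], Any, bool,
                              bool]] = field(default_factory=deque)
    request_outputs: List[str] = field(default_factory=list)


def drain_one(ctx: MiniSchedulerContext) -> None:
    """Pop one pending entry; only the last step of a multi-step batch
    materializes user-visible request outputs."""
    if not ctx.output_queue:
        return
    outputs, metadata, sched_out, is_async, is_last_step = (
        ctx.output_queue.popleft())
    if not is_last_step:
        return  # intermediate step: tokens were appended, no outputs yet
    ctx.request_outputs.append(f"outputs for {len(metadata)} seq groups")


ctx = MiniSchedulerContext()
ctx.output_queue.append((["sampler_out"], ["sg0", "sg1"], None, False, False))
ctx.output_queue.append((["sampler_out"], ["sg0", "sg1"], None, False, True))
drain_one(ctx)  # intermediate step: nothing emitted
drain_one(ctx)  # last step: request outputs created
print(ctx.request_outputs)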
View File

@@ -85,9 +85,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 no tokens need to be appended since it is already done
                 externally (before the next schedule() call)
         """
-        # TODO: Add support for async if necessary
-        assert not is_async
-
         # Sequences can be in RUNNING or FINISHED_ABORTED state
         # once scheduled, as a sequence is moved to FINSIHED_ABORTED
         # if a client disconnects from the api server.
@@ -101,19 +98,41 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
             "Beam search not supported in multi-step decoding.")
         seq = seqs[0]

-        # Since there's only one sequence per sequence group, we can take the
-        # first sample.
-        samples = [output.samples[0] for output in outputs]
-
-        # -1 means the output token is not valid (eg. due to spec decode
-        # rejecting tokens).
-        valid_samples = [
-            sample for sample in samples if sample.output_token != -1
-        ]
-        assert valid_samples
-
-        self._process_seq_outputs(seq, valid_samples,
-                                  sequence_group.sampling_params)
+        if is_async:
+            # Async case: We process tokens one by one. Here, we know the token
+            # was already appended, so we only need to do the rest of the
+            # postprocessor: Detokenization + stopping logic
+            self._process_decode_and_stop(seq, sequence_group.sampling_params)
+        else:
+            # Standard multi-step case
+
+            # Since there's only one sequence per sequence group,
+            # we can take the first sample.
+            samples = [output.samples[0] for output in outputs]
+
+            # -1 means the output token is not valid (eg. due to spec decode
+            # rejecting tokens).
+            valid_samples = [
+                sample for sample in samples if sample.output_token != -1
+            ]
+            assert valid_samples
+
+            self._process_seq_outputs(seq, valid_samples,
+                                      sequence_group.sampling_params)
+
+    def _process_decode_and_stop(self, seq: Sequence,
+                                 sampling_params: SamplingParams) -> None:
+        new_char_count = 0
+        if sampling_params.detokenize:
+            new_char_count = self.detokenizer.decode_sequence_inplace(
+                seq, sampling_params)
+
+        # TODO(sang): Support lora.
+        self.stop_checker.maybe_stop_sequence(
+            seq,
+            new_char_count=new_char_count,
+            sampling_params=sampling_params,
+        )

     def _process_seq_outputs(self, seq: Sequence,
                              valid_samples: List[SequenceOutput],
@@ -151,16 +170,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 logprobs=output_logprob,
             )

-            new_char_count = 0
-            if sampling_params.detokenize:
-                new_char_count = self.detokenizer.decode_sequence_inplace(
-                    seq, sampling_params)
-
-            # TODO(sang): Support lora.
-            self.stop_checker.maybe_stop_sequence(
-                seq,
-                new_char_count=new_char_count,
-                sampling_params=sampling_params,
-            )
+            self._process_decode_and_stop(seq, sampling_params)

             if seq.is_finished():
                 break

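The refactor above factors detokenization plus stop-checking into _process_decode_and_stop, so the async path, where the sampled token has already been appended, only runs that tail of the postprocessor. A rough, self-contained illustration of that "append already done, then decode and stop-check" flow follows; MiniSeq, decode_and_stop, and the stop-token/max-token values are invented for the example.

from dataclasses import dataclass, field
from typing import List


@dataclass
class MiniSeq:
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    finished: bool = False


def decode_and_stop(seq: MiniSeq, stop_token: int, max_tokens: int) -> None:
    """Detokenize only the already-appended last token, then run stop checks."""
    seq.text += f"<{seq.tokens[-1]}>"  # stand-in for incremental detokenization
    if seq.tokens[-1] == stop_token or len(seq.tokens) >= max_tokens:
        seq.finished = True


seq = MiniSeq()
for tok in (5, 7, 2):          # 2 plays the role of a stop token here
    seq.tokens.append(tok)     # the append happened earlier in the pipeline
    decode_and_stop(seq, stop_token=2, max_tokens=16)
print(seq.text, seq.finished)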
View File

@@ -1225,7 +1225,6 @@ class ExecuteModelRequest(
     last_sampled_token_ids: Optional[torch.Tensor] = None
     # Async callback
     async_callback: Optional[Callable] = None
-    use_async_and_multi_step: bool = False

     @property
     def is_first_multi_step(self) -> bool:
@@ -1272,5 +1271,4 @@ class ExecuteModelRequest(
             finished_requests_ids=self.finished_requests_ids,
             last_sampled_token_ids=self.last_sampled_token_ids.clone()
             if self.last_sampled_token_ids is not None else None,
-            async_callback=self.async_callback,
-            use_async_and_multi_step=self.use_async_and_multi_step)
+            async_callback=self.async_callback)

View File

@@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_pp_group
 from vllm.distributed.parallel_state import graph_capture
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -96,7 +97,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
     finished_requests_ids: Optional[List[str]] = None
     virtual_engine: int = 0
     async_callback: Optional[Callable] = None
-    use_async_and_multi_step: bool = False
+    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
+    scheduler_outputs: Optional[SchedulerOutputs] = None

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {

View File

@@ -22,6 +22,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
                                                 get_pythonized_sample_results)
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceGroupMetadata, SequenceOutput)
+from vllm.utils import PyObjectCache
 from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
@@ -37,6 +38,29 @@ if TYPE_CHECKING:

 logger = init_logger(__name__)

+
+def seq_output_builder():
+    return SequenceOutput(
+        0, 0,
+        {0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)})
+
+
+def completion_seq_group_output_builder():
+    return CompletionSequenceGroupOutput([], None)
+
+
+# Used by pythonization to reduce python object allocations
+class PythonizationCache:
+
+    def __init__(self):
+        self.cached_seq_output = PyObjectCache(seq_output_builder)
+        self.cached_completion_seq_group_output = PyObjectCache(
+            completion_seq_group_output_builder)
+
+    def reset(self):
+        self.cached_seq_output.reset()
+        self.cached_completion_seq_group_output.reset()
+
+
 @dataclass
 class ModelOutput:
     """The output of a single model forward pass.
@@ -59,6 +83,7 @@ class ModelOutput:
     pythonized: bool = False
     # On-device tensor containing the logprobs of each token.
     logprobs: Optional["torch.Tensor"] = None
+    pythonization_cache: Optional[PythonizationCache] = None

     def pythonize(self, input_metadata: "StatefulModelInput",
                   copy_stream: torch.cuda.Stream,
@@ -97,7 +122,8 @@ class ModelOutput:
         with torch.cuda.stream(copy_stream):
             _pythonize_sampler_output(input_metadata, self.sampler_output,
                                       pinned_sampled_token_buffer,
-                                      self.sampled_token_ids, self.logprobs)
+                                      self.sampled_token_ids, self.logprobs,
+                                      self.pythonization_cache)

         # Erase the logprobs GPU-side tensor.
         # Note that although _pythonize_sampler_output() runs in its
@@ -209,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         self._copy_stream = torch.cuda.Stream()
         self.pinned_sampled_token_ids: Optional[torch.Tensor] = None

+        self.pythonization_cache = PythonizationCache()
+
     def make_model_input_from_broadcasted_tensor_dict(
             self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
         model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
@@ -237,14 +265,22 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                                output_proc_callback: Callable):
         # Proceed with pythonization and output_proc in order.
         # Stop on the first one that fails to pythonize
+        output_proc_callback()
+
         cont = True
         for model_output in model_input.cached_outputs:
             if not model_output.pythonized:
                 model_output.maybe_pythonize(model_input, self._copy_stream,
                                              self.pinned_sampled_token_ids)
                 if model_output.pythonized:
-                    output_proc_callback(
-                        sampler_output=model_output.sampler_output)
+                    ctx = output_proc_callback.keywords["ctx"]
+                    is_async = False
+                    is_last_step = False
+                    ctx.output_queue.append(
+                        ([model_output.sampler_output
+                          ], ctx.seq_group_metadata_list,
+                         ctx.scheduler_outputs, is_async, is_last_step))
+
+                    output_proc_callback()
                 else:
                     cont = False
@@ -255,21 +291,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                                output_proc_callback: Optional[Callable]):
         assert model_input.frozen_model_input is not None

+        has_async_callback = output_proc_callback is not None
+
         outputs = []
         for output_id in range(len(model_input.cached_outputs)):
-            is_last_output = output_id == len(model_input.cached_outputs) - 1
-
             output = model_input.cached_outputs[output_id]
-            if not output.pythonized:
-                output.pythonize(model_input, self._copy_stream,
-                                 self.pinned_sampled_token_ids)
+            is_last_step = output_id == len(model_input.cached_outputs) - 1

-                if model_input.frozen_model_input.use_async_and_multi_step:
-                    assert output_proc_callback is not None
-                    output_proc_callback(sampler_output=output.sampler_output,
-                                         is_last_output=is_last_output)
+            # For non-async case:
+            #   -- We simply add the outputs
+            # For async case:
+            #   -- Invoke callback, pythonize, add to callback queue and repeat
+            #   -- For last output, just add to callback queue
+            if has_async_callback:
+                assert output_proc_callback is not None

-            outputs.append(output.sampler_output)
+                # Invoke callback before pythonize (to overlap with GPU)
+                output_proc_callback()
+
+                # Pythonize
+                if not output.pythonized:
+                    output.pythonize(model_input, self._copy_stream,
+                                     self.pinned_sampled_token_ids)
+
+                    # For non last step, add to callback queue to chain
+                    # callbacks=>pythonize pairs (for GPU overlap)
+                    if not is_last_step:
+                        ctx = output_proc_callback.keywords[  # type: ignore
+                            "ctx"]  # type: ignore
+                        is_async = False
+                        is_last_step = False
+                        ctx.output_queue.append(
+                            ([output.sampler_output
+                              ], ctx.seq_group_metadata_list,
+                             ctx.scheduler_outputs, is_async, is_last_step))
+                    else:
+                        outputs.append(output.sampler_output)
+            else:
+                output.pythonize(model_input, self._copy_stream,
+                                 self.pinned_sampled_token_ids)
+                outputs.append(output.sampler_output)

         return outputs
@@ -330,7 +391,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                 model_input, model_input.cached_outputs[-1].sampler_output)

             output_proc_callback = None
-            if frozen_model_input.use_async_and_multi_step:
+            if frozen_model_input.async_callback is not None:
                 output_proc_callback = frozen_model_input.async_callback
                 assert output_proc_callback is not None
                 async_callback = functools.partial(
@@ -367,7 +428,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
             model_input.cached_outputs.append(
                 ModelOutput(output[0], output_ready_event,
                             output[0].sampled_token_ids, False,
-                            output[0].logprobs))
+                            output[0].logprobs, self.pythonization_cache))

             # These GPU tensors are not required by multi-step;
             # erase them to ensure they are not pythonized or
@@ -378,7 +439,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):

             # Pythonize the output if CPU is ahead and the previous step is
             # ready.
-            if not frozen_model_input.use_async_and_multi_step:
+            if frozen_model_input.async_callback is None:
                 for model_output in model_input.cached_outputs:
                     model_output.maybe_pythonize(model_input,
                                                  self._copy_stream,
@@ -397,6 +458,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         if model_input.is_last_step:
             outputs = self._final_process_outputs(model_input,
                                                   output_proc_callback)
+            self.pythonization_cache.reset()
             return outputs

         # should be [SamplerOutput]
@@ -537,6 +599,7 @@ def _pythonize_sampler_output(
     pinned_sampled_token_buffer: torch.Tensor,
     sampled_token_ids: torch.Tensor,
     logprobs_tensor: Optional[torch.Tensor],
+    cache: Optional[PythonizationCache],
 ) -> None:
     """ This function is only called when the output tensors are ready.
     See :class:`ModelOutput`.
@@ -597,6 +660,9 @@ def _pythonize_sampler_output(

     for sgdx, (seq_group,
                sample_result) in enumerate(zip(seq_groups, samples_list)):
+        if seq_group.sampling_params.logits_processors:
+            assert len(seq_group.sampling_params.logits_processors) == 0, (
+                "Logits Processors are not supported in multi-step decoding")

         if do_pythonize_logprobs:
             assert prompt_logprobs is not None
@@ -621,23 +687,56 @@ def _pythonize_sampler_output(
         seq_ids = seq_group.seq_ids
         next_token_ids = sample_result
         parent_ids = [0]
-        seq_outputs: List[SequenceOutput] = []
-        if seq_group.sampling_params.logits_processors:
-            assert len(seq_group.sampling_params.logits_processors) == 0, (
-                "Logits Processors are not supported in multi-step decoding")
+
+        if cache is not None:
+            completion_seq_group_output: CompletionSequenceGroupOutput = \
+                cache.cached_completion_seq_group_output.get_object()
+            completion_seq_group_output.samples.clear()
+            seq_outputs: List[
+                SequenceOutput] = completion_seq_group_output.samples
+        else:
+            seq_outputs = []
+
         for tdx, (parent_id,
                   next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
-            seq_outputs.append(
-                SequenceOutput(seq_ids[parent_id], next_token_id,
-                               (group_sample_logprobs[tdx]
-                                if logprobs_are_requested else {
-                                    next_token_id:
-                                    Logprob(logprob=float('inf'),
-                                            rank=None,
-                                            decoded_token=None)
-                                })))
-        output.outputs.append(
-            CompletionSequenceGroupOutput(
-                seq_outputs,
-                (group_prompt_logprobs if logprobs_are_requested else None)))
+            if cache is not None:
+                seq_output: SequenceOutput = cache.cached_seq_output.get_object(
+                )
+                seq_output.parent_seq_id = seq_ids[parent_id]
+                seq_output.output_token = next_token_id
+
+                if logprobs_are_requested:
+                    seq_output.logprobs = group_sample_logprobs[tdx]
+                else:
+                    logprobs = next(iter(seq_output.logprobs.values()))
+                    seq_output.logprobs.clear()
+
+                    logprobs.logprob = float('inf')
+                    logprobs.rank = None
+                    logprobs.decoded_token = None
+
+                    seq_output.logprobs[next_token_id] = logprobs
+
+                seq_outputs.append(seq_output)
+            else:
+                seq_outputs.append(
+                    SequenceOutput(seq_ids[parent_id], next_token_id,
+                                   (group_sample_logprobs[tdx]
+                                    if logprobs_are_requested else {
+                                        next_token_id:
+                                        Logprob(logprob=float('inf'),
+                                                rank=None,
+                                                decoded_token=None)
+                                    })))
+
+        if cache is not None:
+            completion_seq_group_output.prompt_logprobs = \
+                group_prompt_logprobs if logprobs_are_requested else None
+            output.outputs.append(completion_seq_group_output)
+        else:
+            output.outputs.append(
+                CompletionSequenceGroupOutput(
+                    seq_outputs, (group_prompt_logprobs
+                                  if logprobs_are_requested else None)))

     assert len(output.outputs) > 0

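PythonizationCache above reuses SequenceOutput and CompletionSequenceGroupOutput objects across steps via PyObjectCache and resets the cache once the last step's outputs are built, trading repeated allocations for in-place field updates. The class below is a simplified object-pool sketch of that idea, not vLLM's PyObjectCache.

from typing import Callable, List


class MiniObjectCache:
    """Hands out pre-built objects and reuses them after reset(); a simplified
    stand-in for an object-reuse pool in the spirit of vllm.utils.PyObjectCache."""

    def __init__(self, builder: Callable[[], dict]) -> None:
        self._builder = builder
        self._pool: List[dict] = []
        self._next = 0

    def get_object(self) -> dict:
        if self._next == len(self._pool):
            self._pool.append(self._builder())  # grow the pool on demand
        obj = self._pool[self._next]
        self._next += 1
        return obj

    def reset(self) -> None:
        # Objects stay allocated; the next step overwrites their fields in place.
        self._next = 0


cache = MiniObjectCache(lambda: {"output_token": None})
first = cache.get_object()
first["output_token"] = 42
cache.reset()
second = cache.get_object()  # same dict object handed out again
print(first is second)       # True: no new allocation on the second step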
View File

@@ -67,9 +67,7 @@ class MultiStepWorker(Worker):
             if execute_model_req.async_callback:
                 model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                     model_input.frozen_model_input,
-                    async_callback=execute_model_req.async_callback,
-                    use_async_and_multi_step=execute_model_req.
-                    use_async_and_multi_step)
+                    async_callback=execute_model_req.async_callback)
             else:
                 # on subsequent steps we reuse the worker input and model input
                 multi_step_state = self.multi_step_states[virtual_engine]