[Core] Optimize Async + Multi-step (#8050)
This commit is contained in:
parent 95a178f861
commit 6d646d08a2
@@ -103,13 +103,13 @@ async def test_multi_step(
model,
server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)
max_wait_seconds=5 * 240)
test_completions = await completions_with_server_args(
prompts,
model,
ms_server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)
max_wait_seconds=5 * 240)

# Assert multi-step scheduling produces identical tokens
# to single-step scheduling.

@@ -280,40 +280,27 @@ class _AsyncLLMEngine(LLMEngine):
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

ctx = self.scheduler_contexts[virtual_engine]

# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()

# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):

# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()

# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)

# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))
self._process_model_outputs(ctx=ctx)

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
@@ -351,26 +338,20 @@ class _AsyncLLMEngine(LLMEngine):
last_sampled_token_ids=last_sampled_token_ids)

if allow_async_output_proc:
async_callback = self.async_callback_multi_step[
virtual_engine] if use_async_and_multi_step \
else self.async_callback[virtual_engine]

execute_model_req.async_callback = async_callback
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]

# Execute the model.
output = await self.model_executor.execute_model_async(
execute_model_req)

# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
output = []

# Finish the current step for all the sequence groups.
@@ -384,24 +365,22 @@ class _AsyncLLMEngine(LLMEngine):
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()

if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
is_async = allow_async_output_proc
is_last_step = True
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step))

if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if output and allow_async_output_proc:
assert len(
output
) == 1, "Async postprocessor expects only a single output set"
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

if not allow_async_output_proc:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=False)
self._process_model_outputs(ctx=ctx)

# Log stats.
self.do_log_stats(scheduler_outputs, output)
@@ -411,17 +390,12 @@ class _AsyncLLMEngine(LLMEngine):

else:
# Multi-step case
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []
return ctx.request_outputs

if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0

return ctx.request_outputs
@@ -640,6 +614,17 @@ class AsyncLLMEngine:
self.log_requests = log_requests
self.engine = self._init_engine(*args, **kwargs)

# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
# TODO: Currently, disabled for engine_use_ray, ask
# Cody/Will/Woosuk about this case.
self.use_process_request_outputs_callback = not self.engine_use_ray
if self.use_process_request_outputs_callback:
self.engine.process_request_outputs_callback = \
self.process_request_outputs

if self.engine_use_ray:
print_warning_once(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
@@ -883,13 +868,27 @@ class AsyncLLMEngine:
request_outputs = await self.engine.step_async(virtual_engine)

# Put the outputs into the corresponding streams.
finished = True
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
if not self.use_process_request_outputs_callback:
all_finished = self.process_request_outputs(request_outputs)
else:
# For callback case, we only need to detect when all
# requests are finished
all_finished = all(request_output.finished
for request_output in request_outputs)

return not all_finished

def process_request_outputs(self, request_outputs) -> bool:
# Put the outputs into the corresponding streams.
all_finished = True
for request_output in request_outputs:
self._request_tracker.process_request_output(
request_output, verbose=self.log_requests)
finished = finished and request_output.finished
all_finished = all_finished and request_output.finished

return not finished
return all_finished

async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
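The callback branch above works because process_request_outputs() is also what the engine invokes from inside _process_model_outputs, so by the time step_async() returns, the outputs have already been appended to the per-request asyncio streams and only the finished flags need checking. A rough, self-contained sketch of that contract; FakeRequestOutput and StreamTracker are illustrative stand-ins, not vLLM classes:

    import asyncio
    from dataclasses import dataclass
    from typing import Dict, List


    @dataclass
    class FakeRequestOutput:  # stand-in for vllm's RequestOutput
        request_id: str
        finished: bool


    class StreamTracker:
        """Toy request tracker: one asyncio queue per request."""

        def __init__(self) -> None:
            self.queues: Dict[str, asyncio.Queue] = {}

        def process_request_outputs(self, outs: List[FakeRequestOutput]) -> bool:
            # Push each output to its stream immediately; report whether every
            # request in this batch has finished.
            all_finished = True
            for out in outs:
                self.queues.setdefault(out.request_id,
                                       asyncio.Queue()).put_nowait(out)
                all_finished = all_finished and out.finished
            return all_finished


    async def main() -> None:
        tracker = StreamTracker()
        done = tracker.process_request_outputs(
            [FakeRequestOutput("a", True), FakeRequestOutput("b", False)])
        print(done)                          # False: request "b" is still running
        print(await tracker.queues["a"].get())


    asyncio.run(main())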
@@ -93,13 +93,14 @@ class SchedulerOutputState:
@dataclass
class SchedulerContext:
output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
List[SequenceGroupMetadata],
SchedulerOutputs]] = field(
default_factory=lambda: deque())

List[SequenceGroupMetadata], SchedulerOutputs,
bool,
bool]] = field(default_factory=lambda: deque())
request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = field(
default_factory=lambda: [])
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None


class LLMEngine:
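The SchedulerContext.output_queue entries grow from a 3-tuple to a 5-tuple: the two extra booleans record whether the step was produced under async output processing and whether it was the last step of a multi-step batch. A minimal sketch of producing and consuming that shape, with placeholder element types instead of the real SamplerOutput/SequenceGroupMetadata/SchedulerOutputs:

    from collections import deque
    from typing import Any, Deque, List, Optional, Tuple

    # (outputs, seq_group_metadata_list, scheduler_outputs, is_async, is_last_step)
    QueueEntry = Tuple[Optional[List[Any]], List[Any], Any, bool, bool]

    output_queue: Deque[QueueEntry] = deque()

    # Producer side: the engine appends one entry per executed step.
    output_queue.append(([{"token": 42}], ["seq-group-meta"], "sched-outs",
                         True, True))

    # Consumer side: the callback pops and unpacks the same shape.
    (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
     is_last_step) = output_queue.popleft()
    assert outputs is not None and is_last_step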
@@ -357,6 +358,26 @@ class LLMEngine:
# different process.
self.tokenizer.ping()

self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]

self.scheduler_contexts = [
SchedulerContext()
for _ in range(self.parallel_config.pipeline_parallel_size)
]

self.async_callbacks = [
functools.partial(self._process_model_outputs,
ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]

# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self.process_request_outputs_callback = None

# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
@@ -364,9 +385,7 @@ class LLMEngine:
Scheduler(
scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size,
functools.partial(self._process_model_outputs,
virtual_engine=v_id,
is_async=True)
self.async_callbacks[v_id]
if model_config.use_async_output_proc else None)
for v_id in range(parallel_config.pipeline_parallel_size)
]
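Binding the per-virtual-engine context into the callback with functools.partial means the scheduler and the multi-step runner can invoke it with no arguments, or reach the bound context through .keywords, as the runner changes further down do. A small sketch of that pattern; Ctx and process_model_outputs here are toy stand-ins, not the real vLLM signatures:

    import functools


    class Ctx:
        def __init__(self, name: str) -> None:
            self.name = name
            self.processed = 0


    def process_model_outputs(ctx: Ctx) -> None:
        # Everything the callback needs travels inside the bound ctx.
        ctx.processed += 1


    pipeline_parallel_size = 2
    contexts = [Ctx(f"ve{v}") for v in range(pipeline_parallel_size)]
    callbacks = [
        functools.partial(process_model_outputs, ctx=contexts[v])
        for v in range(pipeline_parallel_size)
    ]

    callbacks[1]()                                   # no args at the call site
    assert callbacks[1].keywords["ctx"] is contexts[1]
    print(contexts[1].processed)                     # 1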
@@ -417,30 +436,6 @@ class LLMEngine:
),
))

self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]

self.scheduler_contexts = [
SchedulerContext()
for _ in range(self.parallel_config.pipeline_parallel_size)
]

self.async_callback = [
functools.partial(self._process_model_outputs,
virtual_engine=v_id,
is_async=True)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]

self.async_callback_multi_step = [
functools.partial(self._process_model_outputs,
virtual_engine=v_id,
is_async=False)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]

def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).

@@ -1249,11 +1244,7 @@ class LLMEngine:

return

def _process_model_outputs(self,
virtual_engine: int,
is_async: bool,
sampler_output: Optional[SamplerOutput] = None,
is_last_output: bool = False) -> None:
def _process_model_outputs(self, ctx: SchedulerContext) -> None:
"""Apply the model output to the sequences in the scheduled seq groups.

virtual_engine: The engine id to operate on
@@ -1273,24 +1264,12 @@ class LLMEngine:
"""
now = time.time()

is_multi_step = sampler_output is not None

ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]

if len(ctx.output_queue) == 0:
return None

if is_multi_step:
# Async + multi-step case
(outputs, seq_group_metadata_list,
scheduler_outputs) = ctx.output_queue[0]
assert outputs is None
outputs = [sampler_output]
else:
# Async standard case
(outputs, seq_group_metadata_list,
scheduler_outputs) = ctx.output_queue.popleft()

# Get pending async postprocessor
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step) = ctx.output_queue.popleft()
assert outputs is not None

# Sanity check
@@ -1306,6 +1285,7 @@ class LLMEngine:
outputs_by_sequence_group = outputs

finished_before: List[int] = []
finished_now: List[int] = []
for i, seq_group_meta in enumerate(seq_group_metadata_list):
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

@@ -1343,26 +1323,44 @@ class LLMEngine:

if self.model_config.embedding_mode:
self._process_sequence_group_outputs(seq_group, output)
continue
else:
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(
seq_group, output, is_async)

self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(seq_group, output,
is_async)
if seq_group.is_finished():
finished_now.append(i)

# For async + multi-step, free finished seqs and create outputs
# only on the final step.
if is_multi_step and not is_last_output:
return

for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()

# Create the outputs.
for i, _ in enumerate(seq_group_metadata_list):
# Generate outputs for the requests that finished this iteration
for i in finished_now:
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

if not is_multi_step and i in finished_before:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)

# Free currently finished requests
if finished_now:
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()

# For multi-step, do not create outputs each iteration
if not is_last_step:
# Immediately process request outputs here (if callback is given)
if (finished_now
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs)
return

# Create the outputs
# Note: scheduled_seq_groups and seq_group_metadata_list
# must match with the indices
for i, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):

if i in finished_before or i in finished_now:
continue  # Avoids double processing

seq_group = scheduled_seq_group.seq_group
@@ -1376,11 +1374,15 @@ class LLMEngine:
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)

# For async + multi-step, do stats only on the last output.
# Otherwise, do stats if the execution is async
do_stats = is_multi_step or is_async
# Immediately process request outputs here (if callback is given)
if (ctx.request_outputs
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs)

if do_stats:
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
# LLMEngine/AsyncLLMEngine directly
if is_async:
# Log stats.
self.do_log_stats(scheduler_outputs, outputs, finished_before)

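The reworked _process_model_outputs first emits request outputs only for the sequence groups that finished in this iteration, hands them to process_request_outputs_callback right away, and defers the remaining outputs (and stats) until the last step of a multi-step batch. A condensed sketch of that control flow over plain data; the finished flags and string outputs are illustrative only, not the engine's types:

    from typing import Callable, List, Optional


    def process_step(finished: List[bool],
                     is_last_step: bool,
                     callback: Optional[Callable[[List[str]], None]] = None
                     ) -> List[str]:
        request_outputs: List[str] = []

        # 1) Requests that finished this iteration are emitted immediately.
        finished_now = [i for i, f in enumerate(finished) if f]
        for i in finished_now:
            request_outputs.append(f"req-{i}: finished")
        if finished_now and callback is not None:
            callback(request_outputs)      # quick append to asyncio streams

        # 2) For intermediate multi-step iterations, stop here.
        if not is_last_step:
            return request_outputs

        # 3) On the last step, emit outputs for still-running requests too.
        for i, f in enumerate(finished):
            if not f:
                request_outputs.append(f"req-{i}: in progress")
        return request_outputs


    print(process_step([True, False, False], is_last_step=False))
    print(process_step([True, False, False], is_last_step=True))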
@@ -1485,40 +1487,26 @@ class LLMEngine:
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

ctx = self.scheduler_contexts[virtual_engine]

# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()

# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):

# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()

# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)

# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))
self._process_model_outputs(ctx=ctx)

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
@@ -1555,13 +1543,8 @@ class LLMEngine:
last_sampled_token_ids=last_sampled_token_ids)

if allow_async_output_proc:
async_callback = self.async_callback_multi_step[
virtual_engine] if use_async_and_multi_step \
else self.async_callback[virtual_engine]

execute_model_req.async_callback = async_callback
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]

output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
@@ -1573,10 +1556,8 @@ class LLMEngine:
else:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
output = []

@@ -1590,28 +1571,24 @@ class LLMEngine:
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()

if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
# Add results to the output_queue
# (for async or non-async postprocessing)
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
# Add results to the output_queue
is_async = allow_async_output_proc
is_last_step = True
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step))

if output and allow_async_output_proc:
assert len(output) == 1, (
"Multi step decoding does not work "
"with async output processing.")
if output and allow_async_output_proc:
assert len(output) == 1, (
"Async postprocessor expects only a single output set")

self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

# Check if need to run the usual non-async path
if not allow_async_output_proc:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=False)
self._process_model_outputs(ctx=ctx)

# Log stats.
self.do_log_stats(scheduler_outputs, output)
@@ -1620,17 +1597,12 @@ class LLMEngine:
self.do_tracing(scheduler_outputs)
else:
# Multi-step case
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []
return ctx.request_outputs

if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0

# Stop the execute model loop in parallel workers until there are

@@ -85,9 +85,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
# TODO: Add support for async if necessary
assert not is_async

# Sequences can be in RUNNING or FINISHED_ABORTED state
# once scheduled, as a sequence is moved to FINSIHED_ABORTED
# if a client disconnects from the api server.
@@ -101,19 +98,41 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Beam search not supported in multi-step decoding.")
seq = seqs[0]

# Since there's only one sequence per sequence group, we can take the
# first sample.
samples = [output.samples[0] for output in outputs]
if is_async:
# Async case: We process tokens one by one. Here, we know the token
# was already appended, so we only need to do the rest of the
# postprocessor: Detokenization + stopping logic
self._process_decode_and_stop(seq, sequence_group.sampling_params)
else:
# Standard multi-step case

# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
valid_samples = [
sample for sample in samples if sample.output_token != -1
]
assert valid_samples
# Since there's only one sequence per sequence group,
# we can take the first sample.
samples = [output.samples[0] for output in outputs]

self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
valid_samples = [
sample for sample in samples if sample.output_token != -1
]
assert valid_samples

self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)

def _process_decode_and_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
new_char_count = 0
if sampling_params.detokenize:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)

# TODO(sang): Support lora.
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count=new_char_count,
sampling_params=sampling_params,
)

def _process_seq_outputs(self, seq: Sequence,
valid_samples: List[SequenceOutput],
|
||||
logprobs=output_logprob,
|
||||
)
|
||||
|
||||
new_char_count = 0
|
||||
if sampling_params.detokenize:
|
||||
new_char_count = self.detokenizer.decode_sequence_inplace(
|
||||
seq, sampling_params)
|
||||
self._process_decode_and_stop(seq, sampling_params)
|
||||
|
||||
# TODO(sang): Support lora.
|
||||
self.stop_checker.maybe_stop_sequence(
|
||||
seq,
|
||||
new_char_count=new_char_count,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
if seq.is_finished():
|
||||
break
|
||||
|
||||
@ -1225,7 +1225,6 @@ class ExecuteModelRequest(
|
||||
last_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
# Async callback
|
||||
async_callback: Optional[Callable] = None
|
||||
use_async_and_multi_step: bool = False
|
||||
|
||||
@property
|
||||
def is_first_multi_step(self) -> bool:
|
||||
@ -1272,5 +1271,4 @@ class ExecuteModelRequest(
|
||||
finished_requests_ids=self.finished_requests_ids,
|
||||
last_sampled_token_ids=self.last_sampled_token_ids.clone()
|
||||
if self.last_sampled_token_ids is not None else None,
|
||||
async_callback=self.async_callback,
|
||||
use_async_and_multi_step=self.use_async_and_multi_step)
|
||||
async_callback=self.async_callback)
|
||||
|
||||
@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ObservabilityConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
@ -96,7 +97,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
virtual_engine: int = 0
|
||||
async_callback: Optional[Callable] = None
|
||||
use_async_and_multi_step: bool = False
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
|
||||
scheduler_outputs: Optional[SchedulerOutputs] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
|
||||
@ -22,6 +22,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
|
||||
get_pythonized_sample_results)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||
Logprob, SequenceGroupMetadata, SequenceOutput)
|
||||
from vllm.utils import PyObjectCache
|
||||
from vllm.worker.model_runner import (GPUModelRunnerBase,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
from vllm.worker.model_runner_base import (
|
||||
@ -37,6 +38,29 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def seq_output_builder():
|
||||
return SequenceOutput(
|
||||
0, 0,
|
||||
{0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)})
|
||||
|
||||
|
||||
def completion_seq_group_output_builder():
|
||||
return CompletionSequenceGroupOutput([], None)
|
||||
|
||||
|
||||
# Used by pythonization to reduce python object allocations
|
||||
class PythonizationCache:
|
||||
|
||||
def __init__(self):
|
||||
self.cached_seq_output = PyObjectCache(seq_output_builder)
|
||||
self.cached_completion_seq_group_output = PyObjectCache(
|
||||
completion_seq_group_output_builder)
|
||||
|
||||
def reset(self):
|
||||
self.cached_seq_output.reset()
|
||||
self.cached_completion_seq_group_output.reset()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelOutput:
|
||||
"""The output of a single model forward pass.
|
||||
@ -59,6 +83,7 @@ class ModelOutput:
|
||||
pythonized: bool = False
|
||||
# On-device tensor containing the logprobs of each token.
|
||||
logprobs: Optional["torch.Tensor"] = None
|
||||
pythonization_cache: Optional[PythonizationCache] = None
|
||||
|
||||
def pythonize(self, input_metadata: "StatefulModelInput",
|
||||
copy_stream: torch.cuda.Stream,
|
||||
@ -97,7 +122,8 @@ class ModelOutput:
|
||||
with torch.cuda.stream(copy_stream):
|
||||
_pythonize_sampler_output(input_metadata, self.sampler_output,
|
||||
pinned_sampled_token_buffer,
|
||||
self.sampled_token_ids, self.logprobs)
|
||||
self.sampled_token_ids, self.logprobs,
|
||||
self.pythonization_cache)
|
||||
|
||||
# Erase the logprobs GPU-side tensor.
|
||||
# Note that although _pythonize_sampler_output() runs in its
|
||||
@ -209,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
self._copy_stream = torch.cuda.Stream()
|
||||
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
|
||||
self.pythonization_cache = PythonizationCache()
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
|
||||
model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
|
||||
@ -237,14 +265,22 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
output_proc_callback: Callable):
|
||||
# Proceed with pythonization and output_proc in order.
|
||||
# Stop on the first one that fails to pythonize
|
||||
output_proc_callback()
|
||||
|
||||
cont = True
|
||||
for model_output in model_input.cached_outputs:
|
||||
if not model_output.pythonized:
|
||||
model_output.maybe_pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
if model_output.pythonized:
|
||||
output_proc_callback(
|
||||
sampler_output=model_output.sampler_output)
|
||||
ctx = output_proc_callback.keywords["ctx"]
|
||||
is_async = False
|
||||
is_last_step = False
|
||||
ctx.output_queue.append(
|
||||
([model_output.sampler_output
|
||||
], ctx.seq_group_metadata_list,
|
||||
ctx.scheduler_outputs, is_async, is_last_step))
|
||||
output_proc_callback()
|
||||
else:
|
||||
cont = False
|
||||
|
||||
@ -255,21 +291,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
output_proc_callback: Optional[Callable]):
|
||||
assert model_input.frozen_model_input is not None
|
||||
|
||||
has_async_callback = output_proc_callback is not None
|
||||
|
||||
outputs = []
|
||||
for output_id in range(len(model_input.cached_outputs)):
|
||||
is_last_output = output_id == len(model_input.cached_outputs) - 1
|
||||
|
||||
output = model_input.cached_outputs[output_id]
|
||||
if not output.pythonized:
|
||||
is_last_step = output_id == len(model_input.cached_outputs) - 1
|
||||
|
||||
# For non-async case:
|
||||
# -- We simply add the outputs
|
||||
# For async case:
|
||||
# -- Invoke callback, pythonize, add to callback queue and repeat
|
||||
# -- For last output, just add to callback queue
|
||||
if has_async_callback:
|
||||
assert output_proc_callback is not None
|
||||
|
||||
# Invoke callback before pythonize (to overlap with GPU)
|
||||
output_proc_callback()
|
||||
|
||||
# Pythonize
|
||||
if not output.pythonized:
|
||||
output.pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
|
||||
# For non last step, add to callback queue to chain
|
||||
# callbacks=>pythonize pairs (for GPU overlap)
|
||||
if not is_last_step:
|
||||
ctx = output_proc_callback.keywords[ # type: ignore
|
||||
"ctx"] # type: ignore
|
||||
is_async = False
|
||||
is_last_step = False
|
||||
ctx.output_queue.append(
|
||||
([output.sampler_output
|
||||
], ctx.seq_group_metadata_list,
|
||||
ctx.scheduler_outputs, is_async, is_last_step))
|
||||
else:
|
||||
outputs.append(output.sampler_output)
|
||||
else:
|
||||
output.pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
|
||||
if model_input.frozen_model_input.use_async_and_multi_step:
|
||||
assert output_proc_callback is not None
|
||||
output_proc_callback(sampler_output=output.sampler_output,
|
||||
is_last_output=is_last_output)
|
||||
|
||||
outputs.append(output.sampler_output)
|
||||
outputs.append(output.sampler_output)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -330,7 +391,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
model_input, model_input.cached_outputs[-1].sampler_output)
|
||||
|
||||
output_proc_callback = None
|
||||
if frozen_model_input.use_async_and_multi_step:
|
||||
if frozen_model_input.async_callback is not None:
|
||||
output_proc_callback = frozen_model_input.async_callback
|
||||
assert output_proc_callback is not None
|
||||
async_callback = functools.partial(
|
||||
@ -367,7 +428,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
model_input.cached_outputs.append(
|
||||
ModelOutput(output[0], output_ready_event,
|
||||
output[0].sampled_token_ids, False,
|
||||
output[0].logprobs))
|
||||
output[0].logprobs, self.pythonization_cache))
|
||||
|
||||
# These GPU tensors are not required by multi-step;
|
||||
# erase them to ensure they are not pythonized or
|
||||
@ -378,7 +439,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
|
||||
# Pythonize the output if CPU is ahead and the previous step is
|
||||
# ready.
|
||||
if not frozen_model_input.use_async_and_multi_step:
|
||||
if frozen_model_input.async_callback is None:
|
||||
for model_output in model_input.cached_outputs:
|
||||
model_output.maybe_pythonize(model_input,
|
||||
self._copy_stream,
|
||||
@ -397,6 +458,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
if model_input.is_last_step:
|
||||
outputs = self._final_process_outputs(model_input,
|
||||
output_proc_callback)
|
||||
self.pythonization_cache.reset()
|
||||
return outputs
|
||||
|
||||
# should be [SamplerOutput]
|
||||
@ -537,6 +599,7 @@ def _pythonize_sampler_output(
|
||||
pinned_sampled_token_buffer: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
logprobs_tensor: Optional[torch.Tensor],
|
||||
cache: Optional[PythonizationCache],
|
||||
) -> None:
|
||||
""" This function is only called when the output tensors are ready.
|
||||
See :class:`ModelOutput`.
|
||||
@ -597,6 +660,9 @@ def _pythonize_sampler_output(
|
||||
|
||||
for sgdx, (seq_group,
|
||||
sample_result) in enumerate(zip(seq_groups, samples_list)):
|
||||
if seq_group.sampling_params.logits_processors:
|
||||
assert len(seq_group.sampling_params.logits_processors) == 0, (
|
||||
"Logits Processors are not supported in multi-step decoding")
|
||||
|
||||
if do_pythonize_logprobs:
|
||||
assert prompt_logprobs is not None
|
||||
@ -621,23 +687,56 @@ def _pythonize_sampler_output(
|
||||
seq_ids = seq_group.seq_ids
|
||||
next_token_ids = sample_result
|
||||
parent_ids = [0]
|
||||
seq_outputs: List[SequenceOutput] = []
|
||||
if seq_group.sampling_params.logits_processors:
|
||||
assert len(seq_group.sampling_params.logits_processors) == 0, (
|
||||
"Logits Processors are not supported in multi-step decoding")
|
||||
|
||||
if cache is not None:
|
||||
completion_seq_group_output: CompletionSequenceGroupOutput = \
|
||||
cache.cached_completion_seq_group_output.get_object()
|
||||
completion_seq_group_output.samples.clear()
|
||||
seq_outputs: List[
|
||||
SequenceOutput] = completion_seq_group_output.samples
|
||||
else:
|
||||
seq_outputs = []
|
||||
|
||||
for tdx, (parent_id,
|
||||
next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_ids[parent_id], next_token_id,
|
||||
(group_sample_logprobs[tdx]
|
||||
if logprobs_are_requested else {
|
||||
next_token_id:
|
||||
Logprob(logprob=float('inf'),
|
||||
rank=None,
|
||||
decoded_token=None)
|
||||
})))
|
||||
output.outputs.append(
|
||||
CompletionSequenceGroupOutput(
|
||||
seq_outputs,
|
||||
(group_prompt_logprobs if logprobs_are_requested else None)))
|
||||
if cache is not None:
|
||||
seq_output: SequenceOutput = cache.cached_seq_output.get_object(
|
||||
)
|
||||
seq_output.parent_seq_id = seq_ids[parent_id]
|
||||
seq_output.output_token = next_token_id
|
||||
|
||||
if logprobs_are_requested:
|
||||
seq_output.logprobs = group_sample_logprobs[tdx]
|
||||
else:
|
||||
logprobs = next(iter(seq_output.logprobs.values()))
|
||||
seq_output.logprobs.clear()
|
||||
|
||||
logprobs.logprob = float('inf')
|
||||
logprobs.rank = None
|
||||
logprobs.decoded_token = None
|
||||
|
||||
seq_output.logprobs[next_token_id] = logprobs
|
||||
|
||||
seq_outputs.append(seq_output)
|
||||
|
||||
else:
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_ids[parent_id], next_token_id,
|
||||
(group_sample_logprobs[tdx]
|
||||
if logprobs_are_requested else {
|
||||
next_token_id:
|
||||
Logprob(logprob=float('inf'),
|
||||
rank=None,
|
||||
decoded_token=None)
|
||||
})))
|
||||
if cache is not None:
|
||||
completion_seq_group_output.prompt_logprobs = \
|
||||
group_prompt_logprobs if logprobs_are_requested else None
|
||||
output.outputs.append(completion_seq_group_output)
|
||||
else:
|
||||
output.outputs.append(
|
||||
CompletionSequenceGroupOutput(
|
||||
seq_outputs, (group_prompt_logprobs
|
||||
if logprobs_are_requested else None)))
|
||||
|
||||
assert len(output.outputs) > 0
|
||||
|
||||
@ -67,9 +67,7 @@ class MultiStepWorker(Worker):
|
||||
if execute_model_req.async_callback:
|
||||
model_input.frozen_model_input = dataclasses.replace( # type: ignore
|
||||
model_input.frozen_model_input,
|
||||
async_callback=execute_model_req.async_callback,
|
||||
use_async_and_multi_step=execute_model_req.
|
||||
use_async_and_multi_step)
|
||||
async_callback=execute_model_req.async_callback)
|
||||
else:
|
||||
# on subsequent steps we reuse the worker input and model input
|
||||
multi_step_state = self.multi_step_states[virtual_engine]
|
||||
|
||||