[Core] Optimize Async + Multi-step (#8050)
commit 6d646d08a2
parent 95a178f861
@@ -103,13 +103,13 @@ async def test_multi_step(
         model,
         server_args + distributed_args,
         num_logprobs,
-        max_wait_seconds=3 * 240)
+        max_wait_seconds=5 * 240)
     test_completions = await completions_with_server_args(
         prompts,
         model,
         ms_server_args + distributed_args,
         num_logprobs,
-        max_wait_seconds=3 * 240)
+        max_wait_seconds=5 * 240)

     # Assert multi-step scheduling produces identical tokens
     # to single-step scheduling.
@@ -280,40 +280,27 @@ class _AsyncLLMEngine(LLMEngine):
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

-        # Detect async + multi-step
-        use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                    and allow_async_output_proc)
-
         ctx = self.scheduler_contexts[virtual_engine]

+        # Clear outputs for each new scheduler iteration
+        ctx.request_outputs.clear()
+
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):

-            # Clear outputs on scheduler iteration start
-            ctx.request_outputs.clear()
-
             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()

-            # Detect async + multi-step
-            use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                        and allow_async_output_proc)
+            ctx.seq_group_metadata_list = seq_group_metadata_list
+            ctx.scheduler_outputs = scheduler_outputs

             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
-
-            # For async + multi-step, init the queue
-            if use_async_and_multi_step:
-                assert len(ctx.output_queue) == 0
-                assert seq_group_metadata_list is not None
-                ctx.output_queue.append(
-                    (None, seq_group_metadata_list, scheduler_outputs))
+                self._process_model_outputs(ctx=ctx)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
@@ -351,26 +338,20 @@ class _AsyncLLMEngine(LLMEngine):
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                async_callback = self.async_callback_multi_step[
-                    virtual_engine] if use_async_and_multi_step \
-                        else self.async_callback[virtual_engine]
-
-                execute_model_req.async_callback = async_callback
-                execute_model_req.use_async_and_multi_step = \
-                    use_async_and_multi_step
+                execute_model_req.async_callback = self.async_callbacks[
+                    virtual_engine]

             # Execute the model.
             output = await self.model_executor.execute_model_async(
                 execute_model_req)

             # we need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+            if len(ctx.output_queue) > 0:
+                self._process_model_outputs(ctx=ctx)
             output = []

         # Finish the current step for all the sequence groups.
@@ -384,24 +365,22 @@ class _AsyncLLMEngine(LLMEngine):
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()

-            if use_async_and_multi_step:
-                # For async + multi-step, clear the queue
-                ctx.output_queue.clear()
-            else:
-                ctx.output_queue.append(
-                    (output, seq_group_metadata_list, scheduler_outputs))
+            is_async = allow_async_output_proc
+            is_last_step = True
+            ctx.output_queue.append(
+                (output, seq_group_metadata_list, scheduler_outputs, is_async,
+                 is_last_step))

             if output and allow_async_output_proc:
                 assert len(
                     output
-                ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
+                ) == 1, "Async postprocessor expects only a single output set"
                 self._advance_to_next_step(
                     output[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)

             if not allow_async_output_proc:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=False)
+                self._process_model_outputs(ctx=ctx)

             # Log stats.
             self.do_log_stats(scheduler_outputs, output)
@@ -411,17 +390,12 @@ class _AsyncLLMEngine(LLMEngine):

         else:
             # Multi-step case
-            if use_async_and_multi_step:
-                return []
-            else:
-                ctx.request_outputs = []
+            return ctx.request_outputs

         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)
             if len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+                self._process_model_outputs(ctx=ctx)
             assert len(ctx.output_queue) == 0

         return ctx.request_outputs
@@ -640,6 +614,17 @@ class AsyncLLMEngine:
         self.log_requests = log_requests
         self.engine = self._init_engine(*args, **kwargs)

+        # This ensures quick processing of request outputs
+        # so the append to asyncio queues is not delayed,
+        # especially for multi-step.
+        #
+        # TODO: Currently, disabled for engine_use_ray, ask
+        # Cody/Will/Woosuk about this case.
+        self.use_process_request_outputs_callback = not self.engine_use_ray
+        if self.use_process_request_outputs_callback:
+            self.engine.process_request_outputs_callback = \
+                self.process_request_outputs
+
         if self.engine_use_ray:
             print_warning_once(
                 "DEPRECATED. `--engine-use-ray` is deprecated and will "
@@ -883,13 +868,27 @@ class AsyncLLMEngine:
             request_outputs = await self.engine.step_async(virtual_engine)

         # Put the outputs into the corresponding streams.
-        finished = True
+        # If used as a callback, then already invoked inside
+        # LLMEngine's _process_model_outputs
+        if not self.use_process_request_outputs_callback:
+            all_finished = self.process_request_outputs(request_outputs)
+        else:
+            # For callback case, we only need to detect when all
+            # requests are finished
+            all_finished = all(request_output.finished
+                               for request_output in request_outputs)
+
+        return not all_finished
+
+    def process_request_outputs(self, request_outputs) -> bool:
+        # Put the outputs into the corresponding streams.
+        all_finished = True
         for request_output in request_outputs:
             self._request_tracker.process_request_output(
                 request_output, verbose=self.log_requests)
-            finished = finished and request_output.finished
+            all_finished = all_finished and request_output.finished

-        return not finished
+        return all_finished

     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
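Note (added for clarity, not part of the commit): the AsyncLLMEngine changes above register process_request_outputs on the engine as process_request_outputs_callback, so finished outputs are pushed into the per-request asyncio streams as soon as _process_model_outputs sees them, instead of waiting for step_async to return. The sketch below is a minimal, self-contained illustration of that wiring; ToyEngine and ToyAsyncFrontend are invented stand-ins, not vLLM classes.

from typing import Callable, List, Optional


class ToyEngine:

    def __init__(self) -> None:
        # Mirrors LLMEngine.process_request_outputs_callback in this commit:
        # None by default, set by the async front end for eager delivery.
        self.process_request_outputs_callback: Optional[Callable] = None

    def _process_model_outputs(self, outputs: List[str]) -> None:
        finished = [o for o in outputs if o.endswith("!")]
        # Deliver finished outputs immediately if a callback is registered.
        if finished and self.process_request_outputs_callback is not None:
            self.process_request_outputs_callback(finished)


class ToyAsyncFrontend:

    def __init__(self, engine: ToyEngine) -> None:
        self.engine = engine
        self.delivered: List[str] = []
        # Mirrors AsyncLLMEngine.use_process_request_outputs_callback.
        self.engine.process_request_outputs_callback = self.process_request_outputs

    def process_request_outputs(self, outputs: List[str]) -> bool:
        self.delivered.extend(outputs)
        return all(o.endswith("!") for o in outputs)


engine = ToyEngine()
frontend = ToyAsyncFrontend(engine)
engine._process_model_outputs(["partial", "done!"])
assert frontend.delivered == ["done!"]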
@@ -93,13 +93,14 @@ class SchedulerOutputState:
 @dataclass
 class SchedulerContext:
     output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
-                              List[SequenceGroupMetadata],
-                              SchedulerOutputs]] = field(
-                                  default_factory=lambda: deque())
+                              List[SequenceGroupMetadata], SchedulerOutputs,
+                              bool,
+                              bool]] = field(default_factory=lambda: deque())

     request_outputs: List[Union[RequestOutput,
                                 EmbeddingRequestOutput]] = field(
                                     default_factory=lambda: [])
+    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
+    scheduler_outputs: Optional[SchedulerOutputs] = None


 class LLMEngine:
@@ -357,6 +358,26 @@ class LLMEngine:
             # different process.
             self.tokenizer.ping()

+        self.cached_scheduler_outputs = [
+            SchedulerOutputState()
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
+        self.scheduler_contexts = [
+            SchedulerContext()
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
+        self.async_callbacks = [
+            functools.partial(self._process_model_outputs,
+                              ctx=self.scheduler_contexts[v_id])
+            for v_id in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
+        # Currently used by AsyncLLMEngine to ensure quick append
+        # of request outputs to asyncio queues
+        self.process_request_outputs_callback = None
+
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of
         # GPU and CPU blocks, which are profiled in the distributed executor.
@@ -364,9 +385,7 @@ class LLMEngine:
             Scheduler(
                 scheduler_config, cache_config, lora_config,
                 parallel_config.pipeline_parallel_size,
-                functools.partial(self._process_model_outputs,
-                                  virtual_engine=v_id,
-                                  is_async=True)
+                self.async_callbacks[v_id]
                 if model_config.use_async_output_proc else None)
             for v_id in range(parallel_config.pipeline_parallel_size)
         ]
@@ -417,30 +436,6 @@ class LLMEngine:
             ),
         ))

-        self.cached_scheduler_outputs = [
-            SchedulerOutputState()
-            for _ in range(self.parallel_config.pipeline_parallel_size)
-        ]
-
-        self.scheduler_contexts = [
-            SchedulerContext()
-            for _ in range(self.parallel_config.pipeline_parallel_size)
-        ]
-
-        self.async_callback = [
-            functools.partial(self._process_model_outputs,
-                              virtual_engine=v_id,
-                              is_async=True)
-            for v_id in range(self.parallel_config.pipeline_parallel_size)
-        ]
-
-        self.async_callback_multi_step = [
-            functools.partial(self._process_model_outputs,
-                              virtual_engine=v_id,
-                              is_async=False)
-            for v_id in range(self.parallel_config.pipeline_parallel_size)
-        ]
-
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).

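Note (added for clarity, not part of the commit): with this change there is a single _process_model_outputs(ctx=...) entry point, and the per-virtual-engine callbacks in self.async_callbacks are functools.partial objects that pre-bind the matching SchedulerContext. Queue entries now carry (outputs, seq_group_metadata_list, scheduler_outputs, is_async, is_last_step). A small sketch with stand-in types (not vLLM's real classes) showing the binding, the new tuple shape, and how a consumer can recover the bound context from partial.keywords, as the multi-step runner does later in this diff:

import functools
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Tuple


@dataclass
class ToySchedulerContext:
    # Entry shape mirrors the new queue entries:
    # (outputs, seq_group_metadata_list, scheduler_outputs, is_async, is_last_step)
    output_queue: Deque[Tuple] = field(default_factory=deque)


def process_model_outputs(ctx: ToySchedulerContext) -> None:
    while ctx.output_queue:
        (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
         is_last_step) = ctx.output_queue.popleft()
        # ... apply the sampled outputs to the scheduled sequences here ...


contexts = [ToySchedulerContext() for _ in range(2)]  # pipeline_parallel_size = 2
callbacks = [
    functools.partial(process_model_outputs, ctx=contexts[v_id])
    for v_id in range(2)
]

# A worker-side consumer can recover the bound context, the same way the
# multi-step runner reads output_proc_callback.keywords["ctx"]:
ctx = callbacks[0].keywords["ctx"]
ctx.output_queue.append(([], [], None, False, True))
callbacks[0]()
assert not ctx.output_queue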
@@ -1249,11 +1244,7 @@ class LLMEngine:

         return

-    def _process_model_outputs(self,
-                               virtual_engine: int,
-                               is_async: bool,
-                               sampler_output: Optional[SamplerOutput] = None,
-                               is_last_output: bool = False) -> None:
+    def _process_model_outputs(self, ctx: SchedulerContext) -> None:
         """Apply the model output to the sequences in the scheduled seq groups.

         virtual_engine: The engine id to operate on
@@ -1273,24 +1264,12 @@ class LLMEngine:
         """
         now = time.time()

-        is_multi_step = sampler_output is not None
-
-        ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]
-
         if len(ctx.output_queue) == 0:
             return None

-        if is_multi_step:
-            # Async + multi-step case
-            (outputs, seq_group_metadata_list,
-             scheduler_outputs) = ctx.output_queue[0]
-            assert outputs is None
-            outputs = [sampler_output]
-        else:
-            # Async standard case
-            (outputs, seq_group_metadata_list,
-             scheduler_outputs) = ctx.output_queue.popleft()
-
+        # Get pending async postprocessor
+        (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
+         is_last_step) = ctx.output_queue.popleft()
         assert outputs is not None

         # Sanity check
@@ -1306,6 +1285,7 @@ class LLMEngine:
             outputs_by_sequence_group = outputs

         finished_before: List[int] = []
+        finished_now: List[int] = []
         for i, seq_group_meta in enumerate(seq_group_metadata_list):
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

@@ -1343,26 +1323,44 @@ class LLMEngine:

             if self.model_config.embedding_mode:
                 self._process_sequence_group_outputs(seq_group, output)
-                continue
-
-            self.output_processor.process_prompt_logprob(seq_group, output)
-            if seq_group_meta.do_sample:
-                self.output_processor.process_outputs(seq_group, output,
-                                                      is_async)
-
-        # For async + multi-step, free finished seqs and create outputs
-        # only on the final step.
-        if is_multi_step and not is_last_output:
-            return
-
-        for scheduler in self.scheduler:
-            scheduler.free_finished_seq_groups()
-
-        # Create the outputs.
-        for i, _ in enumerate(seq_group_metadata_list):
+            else:
+                self.output_processor.process_prompt_logprob(seq_group, output)
+                if seq_group_meta.do_sample:
+                    self.output_processor.process_outputs(
+                        seq_group, output, is_async)
+
+            if seq_group.is_finished():
+                finished_now.append(i)
+
+        # Generate outputs for the requests that finished this iteration
+        for i in finished_now:
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

-            if not is_multi_step and i in finished_before:
+            seq_group = scheduled_seq_group.seq_group
+            seq_group.maybe_set_first_token_time(now)
+            request_output = RequestOutputFactory.create(seq_group)
+            ctx.request_outputs.append(request_output)
+
+        # Free currently finished requests
+        if finished_now:
+            for scheduler in self.scheduler:
+                scheduler.free_finished_seq_groups()
+
+        # For multi-step, do not create outputs each iteration
+        if not is_last_step:
+            # Immediately process request outputs here (if callback is given)
+            if (finished_now
+                    and self.process_request_outputs_callback is not None):
+                self.process_request_outputs_callback(ctx.request_outputs)
+            return
+
+        # Create the outputs
+        # Note: scheduled_seq_groups and seq_group_metadata_list
+        # must match with the indices
+        for i, scheduled_seq_group in enumerate(
+                scheduler_outputs.scheduled_seq_groups):
+
+            if i in finished_before or i in finished_now:
                 continue  # Avoids double processing

             seq_group = scheduled_seq_group.seq_group
@@ -1376,11 +1374,15 @@ class LLMEngine:
             request_output = RequestOutputFactory.create(seq_group)
             ctx.request_outputs.append(request_output)

-        # For async + multi-step, do stats only on the last output.
-        # Otherwise, do stats if the execution is async
-        do_stats = is_multi_step or is_async
+        # Immediately process request outputs here (if callback is given)
+        if (ctx.request_outputs
+                and self.process_request_outputs_callback is not None):
+            self.process_request_outputs_callback(ctx.request_outputs)

-        if do_stats:
+        # For async case, we need to record the stats here.
+        # For non-async case, the stats are done in the
+        # LLMEngine/AsyncLLMEngine directly
+        if is_async:
             # Log stats.
             self.do_log_stats(scheduler_outputs, outputs, finished_before)

@@ -1485,40 +1487,26 @@ class LLMEngine:
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

-        # Detect async + multi-step
-        use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                    and allow_async_output_proc)
-
         ctx = self.scheduler_contexts[virtual_engine]

+        # Clear outputs for each new scheduler iteration
+        ctx.request_outputs.clear()
+
         # Skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
-
-            # Clear outputs on scheduler iteration start
-            ctx.request_outputs.clear()
-
             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()

-            # Detect async + multi-step
-            use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                        and allow_async_output_proc)
+            ctx.seq_group_metadata_list = seq_group_metadata_list
+            ctx.scheduler_outputs = scheduler_outputs

             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
-
-            # For async + multi-step, init the queue
-            if use_async_and_multi_step:
-                assert len(ctx.output_queue) == 0
-                assert seq_group_metadata_list is not None
-                ctx.output_queue.append(
-                    (None, seq_group_metadata_list, scheduler_outputs))
+                self._process_model_outputs(ctx=ctx)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
@@ -1555,13 +1543,8 @@ class LLMEngine:
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                async_callback = self.async_callback_multi_step[
-                    virtual_engine] if use_async_and_multi_step \
-                        else self.async_callback[virtual_engine]
-
-                execute_model_req.async_callback = async_callback
-                execute_model_req.use_async_and_multi_step = \
-                    use_async_and_multi_step
+                execute_model_req.async_callback = self.async_callbacks[
+                    virtual_engine]

             output = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)
@@ -1573,10 +1556,8 @@ class LLMEngine:
         else:
             # Nothing scheduled => If there is pending async postprocessor,
             # then finish it here.
-            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+            if len(ctx.output_queue) > 0:
+                self._process_model_outputs(ctx=ctx)
             # No outputs in this case
             output = []

@@ -1590,28 +1571,24 @@ class LLMEngine:
             if self.scheduler_config.is_multi_step:
                 self.cached_scheduler_outputs[0] = SchedulerOutputState()

-            if use_async_and_multi_step:
-                # For async + multi-step, clear the queue
-                ctx.output_queue.clear()
-            else:
-                # Add results to the output_queue
-                # (for async or non-async postprocessing)
-                ctx.output_queue.append(
-                    (output, seq_group_metadata_list, scheduler_outputs))
+            # Add results to the output_queue
+            is_async = allow_async_output_proc
+            is_last_step = True
+            ctx.output_queue.append(
+                (output, seq_group_metadata_list, scheduler_outputs, is_async,
+                 is_last_step))

             if output and allow_async_output_proc:
                 assert len(output) == 1, (
-                    "Multi step decoding does not work "
-                    "with async output processing.")
+                    "Async postprocessor expects only a single output set")

                 self._advance_to_next_step(
                     output[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)

             # Check if need to run the usual non-async path
             if not allow_async_output_proc:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=False)
+                self._process_model_outputs(ctx=ctx)

                 # Log stats.
                 self.do_log_stats(scheduler_outputs, output)
@@ -1620,17 +1597,12 @@ class LLMEngine:
                 self.do_tracing(scheduler_outputs)
         else:
             # Multi-step case
-            if use_async_and_multi_step:
-                return []
-            else:
-                ctx.request_outputs = []
+            return ctx.request_outputs

         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)
             if len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+                self._process_model_outputs(ctx=ctx)
             assert len(ctx.output_queue) == 0

         # Stop the execute model loop in parallel workers until there are
@@ -85,9 +85,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 no tokens need to be appended since it is already done
                 externally (before the next schedule() call)
         """
-        # TODO: Add support for async if necessary
-        assert not is_async
-
         # Sequences can be in RUNNING or FINISHED_ABORTED state
         # once scheduled, as a sequence is moved to FINSIHED_ABORTED
         # if a client disconnects from the api server.
@@ -101,19 +98,41 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 "Beam search not supported in multi-step decoding.")
         seq = seqs[0]

-        # Since there's only one sequence per sequence group, we can take the
-        # first sample.
-        samples = [output.samples[0] for output in outputs]
-
-        # -1 means the output token is not valid (eg. due to spec decode
-        # rejecting tokens).
-        valid_samples = [
-            sample for sample in samples if sample.output_token != -1
-        ]
-        assert valid_samples
-
-        self._process_seq_outputs(seq, valid_samples,
-                                  sequence_group.sampling_params)
+        if is_async:
+            # Async case: We process tokens one by one. Here, we know the token
+            # was already appended, so we only need to do the rest of the
+            # postprocessor: Detokenization + stopping logic
+            self._process_decode_and_stop(seq, sequence_group.sampling_params)
+        else:
+            # Standard multi-step case
+
+            # Since there's only one sequence per sequence group,
+            # we can take the first sample.
+            samples = [output.samples[0] for output in outputs]
+
+            # -1 means the output token is not valid (eg. due to spec decode
+            # rejecting tokens).
+            valid_samples = [
+                sample for sample in samples if sample.output_token != -1
+            ]
+            assert valid_samples
+
+            self._process_seq_outputs(seq, valid_samples,
+                                      sequence_group.sampling_params)
+
+    def _process_decode_and_stop(self, seq: Sequence,
+                                 sampling_params: SamplingParams) -> None:
+        new_char_count = 0
+        if sampling_params.detokenize:
+            new_char_count = self.detokenizer.decode_sequence_inplace(
+                seq, sampling_params)
+
+        # TODO(sang): Support lora.
+        self.stop_checker.maybe_stop_sequence(
+            seq,
+            new_char_count=new_char_count,
+            sampling_params=sampling_params,
+        )

     def _process_seq_outputs(self, seq: Sequence,
                              valid_samples: List[SequenceOutput],
@@ -151,16 +170,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 logprobs=output_logprob,
             )

-            new_char_count = 0
-            if sampling_params.detokenize:
-                new_char_count = self.detokenizer.decode_sequence_inplace(
-                    seq, sampling_params)
-
-            # TODO(sang): Support lora.
-            self.stop_checker.maybe_stop_sequence(
-                seq,
-                new_char_count=new_char_count,
-                sampling_params=sampling_params,
-            )
+            self._process_decode_and_stop(seq, sampling_params)
+
             if seq.is_finished():
                 break
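Note (added for clarity, not part of the commit): in the async path the sampled token id has already been appended by the engine's _advance_to_next_step, so process_outputs(..., is_async=True) only runs the tail of the postprocessor, which is what the new _process_decode_and_stop isolates: detokenization plus the stop check. A toy sketch of that split, with illustrative stand-ins rather than vLLM's Sequence, Detokenizer, or StopChecker:

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToySeq:
    token_ids: List[int] = field(default_factory=list)
    text: str = ""
    finished: bool = False


def append_token(seq: ToySeq, token_id: int) -> None:
    seq.token_ids.append(token_id)  # done eagerly by the engine in the async path


def decode_and_stop(seq: ToySeq, eos_id: int) -> None:
    seq.text = " ".join(str(t) for t in seq.token_ids)  # stand-in "detokenize"
    seq.finished = seq.token_ids[-1] == eos_id          # stand-in stop check


seq = ToySeq()
append_token(seq, 7)      # already done before the callback fires
decode_and_stop(seq, 7)   # what _process_decode_and_stop is responsible for
assert seq.finished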
@@ -1225,7 +1225,6 @@ class ExecuteModelRequest(
     last_sampled_token_ids: Optional[torch.Tensor] = None
     # Async callback
     async_callback: Optional[Callable] = None
-    use_async_and_multi_step: bool = False

     @property
     def is_first_multi_step(self) -> bool:
@@ -1272,5 +1271,4 @@ class ExecuteModelRequest(
             finished_requests_ids=self.finished_requests_ids,
             last_sampled_token_ids=self.last_sampled_token_ids.clone()
             if self.last_sampled_token_ids is not None else None,
-            async_callback=self.async_callback,
-            use_async_and_multi_step=self.use_async_and_multi_step)
+            async_callback=self.async_callback)
@@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_pp_group
 from vllm.distributed.parallel_state import graph_capture
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -96,7 +97,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
     finished_requests_ids: Optional[List[str]] = None
     virtual_engine: int = 0
     async_callback: Optional[Callable] = None
-    use_async_and_multi_step: bool = False
+    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
+    scheduler_outputs: Optional[SchedulerOutputs] = None

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
                                                 get_pythonized_sample_results)
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceGroupMetadata, SequenceOutput)
+from vllm.utils import PyObjectCache
 from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
@@ -37,6 +38,29 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)


+def seq_output_builder():
+    return SequenceOutput(
+        0, 0,
+        {0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)})
+
+
+def completion_seq_group_output_builder():
+    return CompletionSequenceGroupOutput([], None)
+
+
+# Used by pythonization to reduce python object allocations
+class PythonizationCache:
+
+    def __init__(self):
+        self.cached_seq_output = PyObjectCache(seq_output_builder)
+        self.cached_completion_seq_group_output = PyObjectCache(
+            completion_seq_group_output_builder)
+
+    def reset(self):
+        self.cached_seq_output.reset()
+        self.cached_completion_seq_group_output.reset()
+
+
 @dataclass
 class ModelOutput:
     """The output of a single model forward pass.
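Note (added for clarity, not part of the commit): PythonizationCache recycles SequenceOutput and CompletionSequenceGroupOutput objects across the steps of one multi-step batch and is cleared with reset() once the batch finishes. The sketch below shows the same builder/get_object/reset pattern with a generic stand-in pool; the real PyObjectCache lives in vllm.utils and its internals may differ.

from typing import Any, Callable, List


class ToyObjectCache:

    def __init__(self, builder: Callable[[], Any]) -> None:
        self._builder = builder
        self._pool: List[Any] = []
        self._next = 0

    def get_object(self) -> Any:
        # Reuse a previously built object if one is available.
        if self._next == len(self._pool):
            self._pool.append(self._builder())
        obj = self._pool[self._next]
        self._next += 1
        return obj

    def reset(self) -> None:
        # Objects stay allocated; only the cursor moves back.
        self._next = 0


cache = ToyObjectCache(lambda: {"samples": []})
a = cache.get_object()
cache.reset()
b = cache.get_object()
assert a is b  # the same object is recycled across steps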
@@ -59,6 +83,7 @@ class ModelOutput:
     pythonized: bool = False
     # On-device tensor containing the logprobs of each token.
     logprobs: Optional["torch.Tensor"] = None
+    pythonization_cache: Optional[PythonizationCache] = None

     def pythonize(self, input_metadata: "StatefulModelInput",
                   copy_stream: torch.cuda.Stream,
@@ -97,7 +122,8 @@ class ModelOutput:
         with torch.cuda.stream(copy_stream):
             _pythonize_sampler_output(input_metadata, self.sampler_output,
                                       pinned_sampled_token_buffer,
-                                      self.sampled_token_ids, self.logprobs)
+                                      self.sampled_token_ids, self.logprobs,
+                                      self.pythonization_cache)

         # Erase the logprobs GPU-side tensor.
         # Note that although _pythonize_sampler_output() runs in its
@@ -209,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         self._copy_stream = torch.cuda.Stream()
         self.pinned_sampled_token_ids: Optional[torch.Tensor] = None

+        self.pythonization_cache = PythonizationCache()
+
     def make_model_input_from_broadcasted_tensor_dict(
             self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
         model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
@@ -237,14 +265,22 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                                output_proc_callback: Callable):
         # Proceed with pythonization and output_proc in order.
         # Stop on the first one that fails to pythonize
+        output_proc_callback()
+
         cont = True
         for model_output in model_input.cached_outputs:
             if not model_output.pythonized:
                 model_output.maybe_pythonize(model_input, self._copy_stream,
                                              self.pinned_sampled_token_ids)
                 if model_output.pythonized:
-                    output_proc_callback(
-                        sampler_output=model_output.sampler_output)
+                    ctx = output_proc_callback.keywords["ctx"]
+                    is_async = False
+                    is_last_step = False
+                    ctx.output_queue.append(
+                        ([model_output.sampler_output
+                          ], ctx.seq_group_metadata_list,
+                         ctx.scheduler_outputs, is_async, is_last_step))
+                    output_proc_callback()
                 else:
                     cont = False

@@ -255,21 +291,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                               output_proc_callback: Optional[Callable]):
         assert model_input.frozen_model_input is not None

+        has_async_callback = output_proc_callback is not None
+
         outputs = []
         for output_id in range(len(model_input.cached_outputs)):
-            is_last_output = output_id == len(model_input.cached_outputs) - 1
-
             output = model_input.cached_outputs[output_id]
-            if not output.pythonized:
-                output.pythonize(model_input, self._copy_stream,
-                                 self.pinned_sampled_token_ids)
+            is_last_step = output_id == len(model_input.cached_outputs) - 1
+
+            # For non-async case:
+            #   -- We simply add the outputs
+            # For async case:
+            #   -- Invoke callback, pythonize, add to callback queue and repeat
+            #   -- For last output, just add to callback queue
+            if has_async_callback:
+                assert output_proc_callback is not None
+
+                # Invoke callback before pythonize (to overlap with GPU)
+                output_proc_callback()
+
+                # Pythonize
+                if not output.pythonized:
+                    output.pythonize(model_input, self._copy_stream,
+                                     self.pinned_sampled_token_ids)
+
+                    # For non last step, add to callback queue to chain
+                    # callbacks=>pythonize pairs (for GPU overlap)
+                    if not is_last_step:
+                        ctx = output_proc_callback.keywords[  # type: ignore
+                            "ctx"]  # type: ignore
+                        is_async = False
+                        is_last_step = False
+                        ctx.output_queue.append(
+                            ([output.sampler_output
+                              ], ctx.seq_group_metadata_list,
+                             ctx.scheduler_outputs, is_async, is_last_step))
+                    else:
+                        outputs.append(output.sampler_output)
+            else:
+                output.pythonize(model_input, self._copy_stream,
+                                 self.pinned_sampled_token_ids)
+                outputs.append(output.sampler_output)

-                if model_input.frozen_model_input.use_async_and_multi_step:
-                    assert output_proc_callback is not None
-                    output_proc_callback(sampler_output=output.sampler_output,
-                                         is_last_output=is_last_output)
-
-            outputs.append(output.sampler_output)
-
         return outputs

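Note (added for clarity, not part of the commit): in _final_process_outputs the async path interleaves output_proc_callback() with pythonization so CPU-side postprocessing of earlier steps overlaps the remaining work; non-last outputs are pushed back onto the context's queue for a later callback, and only the last step's output is returned directly. A schematic sketch of that control flow using stand-in names, not vLLM's types:

from typing import Callable, List


def drain(queue: List[str]) -> None:
    while queue:
        queue.pop(0)  # stand-in for postprocessing one queued output


def final_process_outputs(raw_outputs: List[int],
                          queue: List[str],
                          callback: Callable[[], None]) -> List[str]:
    results: List[str] = []
    for i, raw in enumerate(raw_outputs):
        is_last_step = i == len(raw_outputs) - 1
        callback()  # overlap: drain already-queued work before pythonizing
        pythonized = f"step-{raw}"
        if not is_last_step:
            queue.append(pythonized)    # handled by a later callback
        else:
            results.append(pythonized)  # last step is returned directly
    return results


queue: List[str] = []
out = final_process_outputs([0, 1, 2], queue, lambda: drain(queue))
assert out == ["step-2"] and not queue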
@@ -330,7 +391,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
             model_input, model_input.cached_outputs[-1].sampler_output)

         output_proc_callback = None
-        if frozen_model_input.use_async_and_multi_step:
+        if frozen_model_input.async_callback is not None:
             output_proc_callback = frozen_model_input.async_callback
             assert output_proc_callback is not None
             async_callback = functools.partial(
@@ -367,7 +428,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
             model_input.cached_outputs.append(
                 ModelOutput(output[0], output_ready_event,
                             output[0].sampled_token_ids, False,
-                            output[0].logprobs))
+                            output[0].logprobs, self.pythonization_cache))

             # These GPU tensors are not required by multi-step;
             # erase them to ensure they are not pythonized or
@@ -378,7 +439,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):

             # Pythonize the output if CPU is ahead and the previous step is
             # ready.
-            if not frozen_model_input.use_async_and_multi_step:
+            if frozen_model_input.async_callback is None:
                 for model_output in model_input.cached_outputs:
                     model_output.maybe_pythonize(model_input,
                                                  self._copy_stream,
@@ -397,6 +458,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         if model_input.is_last_step:
             outputs = self._final_process_outputs(model_input,
                                                   output_proc_callback)
+            self.pythonization_cache.reset()
             return outputs

         # should be [SamplerOutput]
@@ -537,6 +599,7 @@ def _pythonize_sampler_output(
     pinned_sampled_token_buffer: torch.Tensor,
     sampled_token_ids: torch.Tensor,
     logprobs_tensor: Optional[torch.Tensor],
+    cache: Optional[PythonizationCache],
 ) -> None:
     """ This function is only called when the output tensors are ready.
     See :class:`ModelOutput`.
@@ -597,6 +660,9 @@ def _pythonize_sampler_output(

     for sgdx, (seq_group,
                sample_result) in enumerate(zip(seq_groups, samples_list)):
+        if seq_group.sampling_params.logits_processors:
+            assert len(seq_group.sampling_params.logits_processors) == 0, (
+                "Logits Processors are not supported in multi-step decoding")

         if do_pythonize_logprobs:
             assert prompt_logprobs is not None
@@ -621,23 +687,56 @@ def _pythonize_sampler_output(
         seq_ids = seq_group.seq_ids
         next_token_ids = sample_result
         parent_ids = [0]
-        seq_outputs: List[SequenceOutput] = []
-        if seq_group.sampling_params.logits_processors:
-            assert len(seq_group.sampling_params.logits_processors) == 0, (
-                "Logits Processors are not supported in multi-step decoding")
+
+        if cache is not None:
+            completion_seq_group_output: CompletionSequenceGroupOutput = \
+                cache.cached_completion_seq_group_output.get_object()
+            completion_seq_group_output.samples.clear()
+            seq_outputs: List[
+                SequenceOutput] = completion_seq_group_output.samples
+        else:
+            seq_outputs = []
+
         for tdx, (parent_id,
                   next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
-            seq_outputs.append(
-                SequenceOutput(seq_ids[parent_id], next_token_id,
-                               (group_sample_logprobs[tdx]
-                                if logprobs_are_requested else {
-                                    next_token_id:
-                                    Logprob(logprob=float('inf'),
-                                            rank=None,
-                                            decoded_token=None)
-                                })))
-        output.outputs.append(
-            CompletionSequenceGroupOutput(
-                seq_outputs,
-                (group_prompt_logprobs if logprobs_are_requested else None)))
+            if cache is not None:
+                seq_output: SequenceOutput = cache.cached_seq_output.get_object(
+                )
+                seq_output.parent_seq_id = seq_ids[parent_id]
+                seq_output.output_token = next_token_id
+
+                if logprobs_are_requested:
+                    seq_output.logprobs = group_sample_logprobs[tdx]
+                else:
+                    logprobs = next(iter(seq_output.logprobs.values()))
+                    seq_output.logprobs.clear()
+
+                    logprobs.logprob = float('inf')
+                    logprobs.rank = None
+                    logprobs.decoded_token = None
+
+                    seq_output.logprobs[next_token_id] = logprobs
+
+                seq_outputs.append(seq_output)
+
+            else:
+                seq_outputs.append(
+                    SequenceOutput(seq_ids[parent_id], next_token_id,
+                                   (group_sample_logprobs[tdx]
+                                    if logprobs_are_requested else {
+                                        next_token_id:
+                                        Logprob(logprob=float('inf'),
+                                                rank=None,
+                                                decoded_token=None)
+                                    })))
+        if cache is not None:
+            completion_seq_group_output.prompt_logprobs = \
+                group_prompt_logprobs if logprobs_are_requested else None
+            output.outputs.append(completion_seq_group_output)
+        else:
+            output.outputs.append(
+                CompletionSequenceGroupOutput(
+                    seq_outputs, (group_prompt_logprobs
+                                  if logprobs_are_requested else None)))
+
     assert len(output.outputs) > 0
@@ -67,9 +67,7 @@ class MultiStepWorker(Worker):
         if execute_model_req.async_callback:
             model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                 model_input.frozen_model_input,
-                async_callback=execute_model_req.async_callback,
-                use_async_and_multi_step=execute_model_req.
-                use_async_and_multi_step)
+                async_callback=execute_model_req.async_callback)
         else:
             # on subsequent steps we reuse the worker input and model input
             multi_step_state = self.multi_step_states[virtual_engine]