[Core] Async_output_proc: Add virtual engine support (towards pipeline parallel) (#7911)

Alexander Matveev 2024-08-28 03:02:30 -04:00 committed by GitHub
parent 51f86bf487
commit f508e03e7f
6 changed files with 123 additions and 68 deletions
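In short, the diff replaces the engine-wide `output_queue`/`request_outputs` pair with one `SchedulerContext` per virtual engine, plus a list of pre-bound `async_callback` partials, so async output processing keeps its state separate per pipeline-parallel rank. A minimal, self-contained sketch of that pattern (toy names like `EngineContext` and `ToyEngine` are illustrative, not vLLM code):

import functools
from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Deque, List


@dataclass
class EngineContext:  # hypothetical stand-in for SchedulerContext
    output_queue: Deque[str] = field(default_factory=deque)
    request_outputs: List[str] = field(default_factory=list)


class ToyEngine:
    def __init__(self, pipeline_parallel_size: int) -> None:
        self.contexts = [EngineContext() for _ in range(pipeline_parallel_size)]
        # Bind the virtual engine id once, so callers need no arguments.
        self.async_callback: List[Callable[[], None]] = [
            functools.partial(self._process_outputs, virtual_engine=v_id)
            for v_id in range(pipeline_parallel_size)
        ]

    def _process_outputs(self, virtual_engine: int) -> None:
        ctx = self.contexts[virtual_engine]
        while ctx.output_queue:
            ctx.request_outputs.append(ctx.output_queue.popleft())


engine = ToyEngine(pipeline_parallel_size=2)
engine.contexts[1].output_queue.append("token")
engine.async_callback[1]()  # processes only virtual engine 1
assert engine.contexts[0].request_outputs == []
assert engine.contexts[1].request_outputs == ["token"]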

vllm/core/scheduler.py

@@ -302,7 +302,7 @@ class Scheduler:
         cache_config: CacheConfig,
         lora_config: Optional[LoRAConfig],
         pipeline_parallel_size: int = 1,
-        output_proc_callback_fn: Optional[Callable] = None,
+        output_proc_callback: Optional[Callable] = None,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
@@ -376,8 +376,8 @@ class Scheduler:
         # iterations. I.e. since the output processing is lagged one step,
         # we cannot reuse the cached objects immediately when the schedule()
         # is called again, but only when schedule() is called the second time.
-        self.output_proc_callback_fn = output_proc_callback_fn
-        self.use_async_output_proc = self.output_proc_callback_fn is not None
+        self.output_proc_callback = output_proc_callback
+        self.use_async_output_proc = self.output_proc_callback is not None
         self.num_cache_iters = 2 if self.use_async_output_proc else 1
         self.cache_id = 0
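The comment above states the key invariant: with async output processing the postprocessor lags one step behind, so `num_cache_iters` doubles to 2 and `cache_id` alternates between two slots, and a slot is only reused on the second `schedule()` call after it was filled. A hedged sketch of that double-buffering idea, with hypothetical names (`TwoSlotCache` is not the scheduler's real structure):

class TwoSlotCache:
    def __init__(self, use_async_output_proc: bool) -> None:
        self.num_cache_iters = 2 if use_async_output_proc else 1
        self.cache_id = 0
        self.slots = [{} for _ in range(self.num_cache_iters)]

    def current(self) -> dict:
        return self.slots[self.cache_id]

    def advance(self) -> None:
        # Rotate after each schedule(); with two slots, the slot written at
        # step N is not touched again until step N+2.
        self.cache_id = (self.cache_id + 1) % self.num_cache_iters


cache = TwoSlotCache(use_async_output_proc=True)
first = cache.current()
cache.advance()
assert cache.current() is not first  # step N+1 uses the other slot
cache.advance()
assert cache.current() is first      # reused only at step N+2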
@@ -573,8 +573,8 @@ class Scheduler:
                     seq_group):
                tmp = self.running
                self.running = orig_running
-               assert self.output_proc_callback_fn is not None
-               self.output_proc_callback_fn(is_async=True)
+               assert self.output_proc_callback is not None
+               self.output_proc_callback()
                self.running = tmp

            while not self._can_append_slots(seq_group):
@@ -1091,7 +1091,6 @@ class Scheduler:
        no_beam_search = seq_group.sampling_params is None or (
            seq_group.sampling_params.best_of == 1
            and not seq_group.sampling_params.use_beam_search)
-
        return no_beam_search

    def schedule(
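The `no_beam_search` predicate above is what gates async output processing: beam search can add or prune sequences while outputs are being processed, which cannot safely overlap the next step's forward pass. A small illustrative check of the same rule, with a stand-in `Params` class in place of vLLM's `SamplingParams`:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Params:  # stand-in for SamplingParams
    best_of: int = 1
    use_beam_search: bool = False


def group_allows_async(params: Optional[Params]) -> bool:
    # Beam search may add or prune sequences while outputs are processed,
    # so such groups must take the synchronous path.
    return params is None or (params.best_of == 1
                              and not params.use_beam_search)


assert group_allows_async(Params())
assert not group_allows_async(Params(best_of=4, use_beam_search=True))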

vllm/engine/async_llm_engine.py

@@ -279,10 +279,16 @@ class _AsyncLLMEngine(LLMEngine):
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

+       ctx = self.scheduler_contexts[virtual_engine]
+
        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
+           # Clear outputs on scheduler iteration start
+           ctx.request_outputs.clear()
+
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()
@@ -290,8 +296,9 @@ class _AsyncLLMEngine(LLMEngine):
            # If current scheduler iteration has no async postprocessor,
            # then we need first to drain the pending async postprocessor
            # before moving forward
-           if not allow_async_output_proc and len(self.output_queue) > 0:
-               self._process_model_outputs(is_async=True)
+           if not allow_async_output_proc and len(ctx.output_queue) > 0:
+               self._process_model_outputs(virtual_engine=virtual_engine,
+                                           is_async=True)

        if (self.scheduler_config.is_multi_step
                and scheduler_outputs.num_lookahead_slots > 0):
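The drain rule above keeps request outputs ordered: if this iteration must take the synchronous path while earlier async steps still have queued outputs, those are flushed first. A toy sketch of the same handshake (names are illustrative, not the engine's internals):

from collections import deque

output_queue = deque(["step-1 outputs"])  # left over from the async path
processed = []


def process_model_outputs() -> None:
    while output_queue:
        processed.append(output_queue.popleft())


allow_async_output_proc = False  # this step must run synchronously
if not allow_async_output_proc and len(output_queue) > 0:
    process_model_outputs()      # drain before moving forward

assert processed == ["step-1 outputs"]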
@@ -332,8 +339,8 @@ class _AsyncLLMEngine(LLMEngine):
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
-               execute_model_req.output_proc_callback_fn = \
-                   self._process_model_outputs
+               execute_model_req.async_callback = self.async_callback[
+                   virtual_engine]

            # Execute the model.
            output = await self.model_executor.execute_model_async(
@@ -343,9 +350,10 @@ class _AsyncLLMEngine(LLMEngine):
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, output)
        else:
-           if len(self.output_queue) > 0:
+           if len(ctx.output_queue) > 0:
                assert not self.scheduler_config.is_multi_step
-               self._process_model_outputs(is_async=True)
+               self._process_model_outputs(virtual_engine=virtual_engine,
+                                           is_async=True)
            output = []

        # Finish the current step for all the sequence groups.
@@ -360,7 +368,7 @@ class _AsyncLLMEngine(LLMEngine):
                virtual_engine] = SchedulerOutputState()

        # Cache results in engine
-       self.output_queue.append(
+       ctx.output_queue.append(
            (output, seq_group_metadata_list, scheduler_outputs))

        if output and allow_async_output_proc:
@@ -372,7 +380,8 @@ class _AsyncLLMEngine(LLMEngine):
                scheduler_outputs.scheduled_seq_groups)

        if not allow_async_output_proc:
-           self._process_model_outputs(is_async=False)
+           self._process_model_outputs(virtual_engine=virtual_engine,
+                                       is_async=False)

            # Log stats.
            self.do_log_stats(scheduler_outputs, output)
@@ -381,9 +390,17 @@ class _AsyncLLMEngine(LLMEngine):
            self.do_tracing(scheduler_outputs)
        else:
-           self.request_outputs = []
+           ctx.request_outputs = []

-       return self.request_outputs
+       if not self.has_unfinished_requests():
+           # Drain async postprocessor (if exists)
+           if len(ctx.output_queue) > 0:
+               assert not self.scheduler_config.is_multi_step
+               self._process_model_outputs(virtual_engine=virtual_engine,
+                                           is_async=True)
+           assert len(ctx.output_queue) == 0
+
+       return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""

vllm/engine/llm_engine.py

@@ -1,7 +1,8 @@
+import functools
 import time
 from collections import deque
 from contextlib import contextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
                     Mapping, Optional)
 from typing import Sequence as GenericSequence
@@ -88,6 +89,17 @@ class SchedulerOutputState:
     last_output: Optional[SamplerOutput] = None


+@dataclass
+class SchedulerContext:
+    output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
+                              SchedulerOutputs]] = field(
+                                  default_factory=lambda: deque())
+
+    request_outputs: List[Union[RequestOutput,
+                                EmbeddingRequestOutput]] = field(
+                                    default_factory=lambda: [])
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -350,9 +362,11 @@ class LLMEngine:
            Scheduler(
                scheduler_config, cache_config, lora_config,
                parallel_config.pipeline_parallel_size,
-               self._process_model_outputs
+               functools.partial(self._process_model_outputs,
+                                 virtual_engine=v_id,
+                                 is_async=True)
                if model_config.use_async_output_proc else None)
-           for _ in range(parallel_config.pipeline_parallel_size)
+           for v_id in range(parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
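Note the switch from `for _` to `for v_id` together with `functools.partial`: a lambda closing over the loop variable would see only the final `v_id` (Python closures bind late), while `partial` freezes each id at construction time. A short demonstration of the difference:

import functools

def process(virtual_engine: int) -> int:
    return virtual_engine

lambdas = [lambda: process(v_id) for v_id in range(3)]
partials = [functools.partial(process, virtual_engine=v_id) for v_id in range(3)]

assert [f() for f in lambdas] == [2, 2, 2]   # late binding: all see v_id == 2
assert [f() for f in partials] == [0, 1, 2]  # partial binds each id eagerly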
@@ -406,12 +420,17 @@ class LLMEngine:
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

-       # Async output processing pointers
-       self.output_queue: Deque[Tuple[List[SamplerOutput],
-                                      List[SequenceGroupMetadata],
-                                      SchedulerOutputs]] = deque()
-       self.request_outputs: List[Union[RequestOutput,
-                                        EmbeddingRequestOutput]] = []
+       self.scheduler_contexts = [
+           SchedulerContext()
+           for _ in range(self.parallel_config.pipeline_parallel_size)
+       ]
+
+       self.async_callback = [
+           functools.partial(self._process_model_outputs,
+                             virtual_engine=v_id,
+                             is_async=True)
+           for v_id in range(self.parallel_config.pipeline_parallel_size)
+       ]

    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).
@@ -1221,32 +1240,28 @@ class LLMEngine:
            return

-   def _process_model_outputs(self,
-                              is_async: bool,
-                              clear_outputs: bool = True) -> None:
+   def _process_model_outputs(self, virtual_engine: int,
+                              is_async: bool) -> None:
        """Apply the model output to the sequences in the scheduled seq groups.

+       virtual_engine: The engine id to operate on
+
        is_async: Indicates whether this postprocessor runs in
            parallel with the GPU forward pass and is processing
            tokens from the previous step. If this is true, then
            no tokens need to be appended since it is already done
            externally (before the next schedule() call)

-       clear_outputs: Sometimes existing outputs need to be combined
-           with outputs of this call. This happens for postprocessor
-           draining at the final stage (like when sequences are finished)
-
        Returns RequestOutputs that can be returned to the client.
        """
        now = time.time()

-       if clear_outputs:
-           self.request_outputs.clear()
+       ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]

-       if len(self.output_queue) == 0:
+       if len(ctx.output_queue) == 0:
            return None

        (outputs, seq_group_metadata_list,
-        scheduler_outputs) = self.output_queue.popleft()
+        scheduler_outputs) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
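The `popleft()` above is the consumer half of a simple handshake: `step()` appends one `(outputs, seq_group_metadata_list, scheduler_outputs)` tuple per iteration, and each callback invocation pops exactly one, so a postprocessor lagging one step never skips or reprocesses a batch. Sketched with simplified payloads (not the engine's real types):

from collections import deque

ctx_output_queue = deque()

# Producer side (end of step N):
ctx_output_queue.append((["sampler_output"], ["seq_metadata"], "sched_outputs"))

# Consumer side (async callback during step N+1):
if len(ctx_output_queue) == 0:
    batch = None  # nothing pending, callback is a no-op
else:
    batch = ctx_output_queue.popleft()

assert batch is not None and len(ctx_output_queue) == 0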
@@ -1321,11 +1336,11 @@ class LLMEngine:
                if (seq_group.is_finished()
                        if self.step_return_finished_only else True):
                    request_output = RequestOutputFactory.create(seq_group)
-                   self.request_outputs.append(request_output)
+                   ctx.request_outputs.append(request_output)

        for seq_group in scheduler_outputs.ignored_seq_groups:
            request_output = RequestOutputFactory.create(seq_group)
-           self.request_outputs.append(request_output)
+           ctx.request_outputs.append(request_output)

        if is_async:
            # Log stats.
@@ -1421,29 +1436,43 @@ class LLMEngine:
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

+       # For llm_engine, there is no pipeline parallel support, so the engine
+       # used is always 0
+       virtual_engine = 0
+
        # These are cached outputs from previous iterations. None if on first
        # iteration
-       cached_outputs = self.cached_scheduler_outputs[0]
+       cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

+       ctx = self.scheduler_contexts[virtual_engine]
+
+       # Skip the scheduler if there are any remaining steps in the seq groups.
+       # This ensures that the scheduler is only called again when the current
+       # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
-           (seq_group_metadata_list, scheduler_outputs,
-            allow_async_output_proc) = self.scheduler[0].schedule()
-
-           if not allow_async_output_proc and len(self.output_queue) > 0:
-               self._process_model_outputs(is_async=True)
+           # Clear outputs on scheduler iteration start
+           ctx.request_outputs.clear()
+
+           # Schedule iteration
+           (seq_group_metadata_list, scheduler_outputs,
+            allow_async_output_proc
+            ) = self.scheduler[virtual_engine].schedule()
+
+           # Maybe switch from async mode to sync mode
+           if not allow_async_output_proc and len(ctx.output_queue) > 0:
+               self._process_model_outputs(virtual_engine=virtual_engine,
+                                           is_async=True)

        if (self.scheduler_config.is_multi_step
                and scheduler_outputs.num_lookahead_slots > 0):
            # cache the scheduler outputs for the next iteration if we have
            # lookahead slots
            self._cache_scheduler_outputs_for_multi_step(
-               0, seq_group_metadata_list, scheduler_outputs,
+               virtual_engine, seq_group_metadata_list, scheduler_outputs,
                allow_async_output_proc)

        assert seq_group_metadata_list is not None
@@ -1454,14 +1483,14 @@ class LLMEngine:
        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
-               0].get_and_reset_finished_requests_ids()
+               virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
-               self._get_last_sampled_token_ids(0)
+               self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
@@ -1476,20 +1505,24 @@ class LLMEngine:
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
-               execute_model_req.output_proc_callback_fn = \
-                   self._process_model_outputs
+               execute_model_req.async_callback = self.async_callback[
+                   virtual_engine]

            output = self.model_executor.execute_model(
                execute_model_req=execute_model_req)

-           # we need to do this here so that last step's sampled_token_ids can
+           # We need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
-               self._update_cached_scheduler_output(0, output)
+               self._update_cached_scheduler_output(virtual_engine, output)
        else:
-           if len(self.output_queue) > 0:
+           # Nothing scheduled => If there is pending async postprocessor,
+           # then finish it here.
+           if len(ctx.output_queue) > 0:
                assert not self.scheduler_config.is_multi_step
-               self._process_model_outputs(is_async=True)
+               self._process_model_outputs(virtual_engine=virtual_engine,
+                                           is_async=True)
+           # No outputs in this case
            output = []

        # Finish the current step for all the sequence groups.
@@ -1504,7 +1537,7 @@ class LLMEngine:
        # Add results to the output_queue
        # (for async or non-async postprocessing)
-       self.output_queue.append(
+       ctx.output_queue.append(
            (output, seq_group_metadata_list, scheduler_outputs))

        if output and allow_async_output_proc:
@@ -1515,8 +1548,10 @@ class LLMEngine:
                output[0], seq_group_metadata_list,
                scheduler_outputs.scheduled_seq_groups)

+       # Check if need to run the usual non-async path
        if not allow_async_output_proc:
-           self._process_model_outputs(is_async=False)
+           self._process_model_outputs(virtual_engine=virtual_engine,
+                                       is_async=False)

            # Log stats.
            self.do_log_stats(scheduler_outputs, output)
@@ -1524,14 +1559,16 @@ class LLMEngine:
            # Tracing
            self.do_tracing(scheduler_outputs)
        else:
-           self.request_outputs = []
+           # Multi-step case
+           ctx.request_outputs = []

        if not self.has_unfinished_requests():
-           # Drain async postprocessor
-           if len(self.output_queue) > 0:
+           # Drain async postprocessor (if exists)
+           if len(ctx.output_queue) > 0:
                assert not self.scheduler_config.is_multi_step
-               self._process_model_outputs(is_async=True, clear_outputs=False)
-           assert len(self.output_queue) == 0
+               self._process_model_outputs(virtual_engine=virtual_engine,
+                                           is_async=True)
+           assert len(ctx.output_queue) == 0

        # Stop the execute model loop in parallel workers until there are
        # more requests to process. This avoids waiting indefinitely in
@@ -1540,7 +1577,7 @@ class LLMEngine:
        # queued control plane messages, such as add/remove lora adapters.
        self.model_executor.stop_remote_worker_execution_loop()

-       return self.request_outputs
+       return ctx.request_outputs

    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]

vllm/sequence.py

@@ -811,6 +811,9 @@ class SequenceGroup:
        self.is_single_seq = len(self.seqs) == 1

    def is_finished(self) -> bool:
+       if self.is_single_seq:
+           return self.seqs[0].is_finished()
+
        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
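The added branch is a hot-path micro-optimization: most sequence groups hold exactly one sequence, and indexing `seqs[0]` avoids building a generator for `all()` on every `is_finished()` call. A rough, illustrative timing harness (absolute numbers will vary by machine):

import timeit

seqs = [object()]
is_single_seq = len(seqs) == 1

def finished(_seq) -> bool:
    return False

t_all = timeit.timeit(lambda: all(finished(s) for s in seqs), number=100_000)
t_fast = timeit.timeit(lambda: finished(seqs[0]) if is_single_seq else None,
                       number=100_000)
print(f"all(): {t_all:.3f}s  fast path: {t_fast:.3f}s")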
@@ -1290,8 +1293,8 @@ class ExecuteModelRequest(
    finished_requests_ids: List[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
-   # Async postprocessor
-   output_proc_callback_fn: Optional[Callable] = None
+   # Async callback
+   async_callback: Optional[Callable] = None

    @property
    def is_first_multi_step(self) -> bool:
@@ -1338,4 +1341,4 @@ class ExecuteModelRequest(
            finished_requests_ids=self.finished_requests_ids,
            last_sampled_token_ids=self.last_sampled_token_ids.clone()
            if self.last_sampled_token_ids is not None else None,
-           output_proc_callback_fn=self.output_proc_callback_fn)
+           async_callback=self.async_callback)

vllm/worker/model_runner.py

@@ -91,7 +91,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_requests_ids: Optional[List[str]] = None
    virtual_engine: int = 0
-   output_proc_callback_fn: Optional[Callable] = None
+   async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
@@ -1457,8 +1457,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
        if not self.is_driver_worker:
            return []

-       if model_input.output_proc_callback_fn is not None:
-           model_input.output_proc_callback_fn(is_async=True)
+       if model_input.async_callback is not None:
+           model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
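Invoking `async_callback()` here, after the forward pass has been issued and just before sampling, is what creates the overlap: the CPU processes step N-1's outputs while the GPU is still busy with step N. A conceptual, hypothetical sketch of the hook placement (not the real `ModelRunner` code):

from typing import Callable, Optional


def execute_model(async_callback: Optional[Callable[[], None]] = None) -> str:
    # Forward pass: on a real GPU this work is enqueued asynchronously and
    # the CPU returns immediately while kernels run.
    hidden_states = "hidden_states"

    # Overlap window: the GPU is still computing, so the CPU is free to
    # postprocess the previous step's outputs here.
    if async_callback is not None:
        async_callback()

    # Sampling consumes the forward results (this is where the CPU would
    # actually synchronize with the GPU).
    return f"sample({hidden_states})"


done = []
print(execute_model(lambda: done.append("processed step N-1 outputs")))
assert done == ["processed step N-1 outputs"]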

vllm/worker/worker_base.py

@@ -263,11 +263,10 @@ class LocalOrDistributedWorkerBase(WorkerBase):
            broadcast_data.update(kwargs)
            broadcast_tensor_dict(broadcast_data, src=0)

-       if execute_model_req.output_proc_callback_fn:
+       if execute_model_req.async_callback:
            model_input = dataclasses.replace(  # type: ignore
                model_input,
-               output_proc_callback_fn=execute_model_req.
-               output_proc_callback_fn)
+               async_callback=execute_model_req.async_callback)

        return model_input, worker_input, kwargs
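The callback is process-local, so it is attached only on the driver and only after `broadcast_tensor_dict`; `dataclasses.replace` then produces an updated copy rather than mutating in place, which also works when the model-input dataclass is frozen and direct assignment would raise. A minimal sketch with simplified types:

import dataclasses
from typing import Callable, Optional


@dataclasses.dataclass(frozen=True)
class ModelInput:  # stand-in for ModelInputForGPU
    virtual_engine: int = 0
    async_callback: Optional[Callable[[], None]] = None


base = ModelInput(virtual_engine=1)
with_cb = dataclasses.replace(base, async_callback=lambda: None)

assert base.async_callback is None  # original untouched
assert with_cb.async_callback is not None and with_cb.virtual_engine == 1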