[Core] Async_output_proc: Add virtual engine support (towards pipeline parallel) (#7911)
parent 51f86bf487
commit f508e03e7f
@@ -302,7 +302,7 @@ class Scheduler:
         cache_config: CacheConfig,
         lora_config: Optional[LoRAConfig],
         pipeline_parallel_size: int = 1,
-        output_proc_callback_fn: Optional[Callable] = None,
+        output_proc_callback: Optional[Callable] = None,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
@@ -376,8 +376,8 @@ class Scheduler:
         # iterations. I.e. since the output processing is lagged one step,
         # we cannot reuse the cached objects immediately when the schedule()
         # is called again, but only when schedule() is called the second time.
-        self.output_proc_callback_fn = output_proc_callback_fn
-        self.use_async_output_proc = self.output_proc_callback_fn is not None
+        self.output_proc_callback = output_proc_callback
+        self.use_async_output_proc = self.output_proc_callback is not None
         self.num_cache_iters = 2 if self.use_async_output_proc else 1

         self.cache_id = 0
@@ -573,8 +573,8 @@ class Scheduler:
                     seq_group):
                 tmp = self.running
                 self.running = orig_running
-                assert self.output_proc_callback_fn is not None
-                self.output_proc_callback_fn(is_async=True)
+                assert self.output_proc_callback is not None
+                self.output_proc_callback()
                 self.running = tmp

             while not self._can_append_slots(seq_group):
@@ -1091,7 +1091,6 @@ class Scheduler:
         no_beam_search = seq_group.sampling_params is None or (
             seq_group.sampling_params.best_of == 1
             and not seq_group.sampling_params.use_beam_search)
-
         return no_beam_search

     def schedule(
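
The scheduler hunks above replace the keyword-style callback with a plain no-argument callable and keep a small double-buffered object cache (num_cache_iters, cache_id), since output processing now lags scheduling by one step. Below is a minimal sketch of that double-buffering idea; only the attribute names num_cache_iters and cache_id come from the diff, the class and its helpers are hypothetical.

# Hypothetical illustration of the double-buffered cache implied by
# num_cache_iters / cache_id in the hunk above; everything except those two
# attribute names is invented for the example.
from collections import defaultdict


class DoubleBufferedCache:

    def __init__(self, use_async_output_proc: bool) -> None:
        # With async output processing, step N's outputs are still being
        # consumed while schedule() for step N+1 runs, so objects cached at
        # step N can only be reused at step N+2.
        self.num_cache_iters = 2 if use_async_output_proc else 1
        self.cache_id = 0
        self._caches = [defaultdict(list) for _ in range(self.num_cache_iters)]

    @property
    def current(self) -> dict:
        return self._caches[self.cache_id]

    def advance(self) -> None:
        # Called once per schedule() iteration; alternates between the two
        # slots (or stays on slot 0 in the synchronous case).
        self.cache_id = (self.cache_id + 1) % self.num_cache_iters
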
@@ -279,10 +279,16 @@ class _AsyncLLMEngine(LLMEngine):
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

+        ctx = self.scheduler_contexts[virtual_engine]
+
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
+
+            # Clear outputs on scheduler iteration start
+            ctx.request_outputs.clear()
+
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()
@@ -290,8 +296,9 @@ class _AsyncLLMEngine(LLMEngine):
             # If current scheduler iteration has no async postprocessor,
             # then we need first to drain the pending async postprocessor
             # before moving forward
-            if not allow_async_output_proc and len(self.output_queue) > 0:
-                self._process_model_outputs(is_async=True)
+            if not allow_async_output_proc and len(ctx.output_queue) > 0:
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)

         if (self.scheduler_config.is_multi_step
                 and scheduler_outputs.num_lookahead_slots > 0):
@@ -332,8 +339,8 @@ class _AsyncLLMEngine(LLMEngine):
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                execute_model_req.output_proc_callback_fn = \
-                    self._process_model_outputs
+                execute_model_req.async_callback = self.async_callback[
+                    virtual_engine]

             # Execute the model.
             output = await self.model_executor.execute_model_async(
@@ -343,9 +350,10 @@ class _AsyncLLMEngine(LLMEngine):
             if self.scheduler_config.is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(self.output_queue) > 0:
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
             output = []

         # Finish the current step for all the sequence groups.
@@ -360,7 +368,7 @@ class _AsyncLLMEngine(LLMEngine):
                     virtual_engine] = SchedulerOutputState()

             # Cache results in engine
-            self.output_queue.append(
+            ctx.output_queue.append(
                 (output, seq_group_metadata_list, scheduler_outputs))

             if output and allow_async_output_proc:
@@ -372,7 +380,8 @@ class _AsyncLLMEngine(LLMEngine):
                     scheduler_outputs.scheduled_seq_groups)

             if not allow_async_output_proc:
-                self._process_model_outputs(is_async=False)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=False)

             # Log stats.
             self.do_log_stats(scheduler_outputs, output)
@@ -381,9 +390,17 @@ class _AsyncLLMEngine(LLMEngine):
             self.do_tracing(scheduler_outputs)

         else:
-            self.request_outputs = []
+            ctx.request_outputs = []

-        return self.request_outputs
+        if not self.has_unfinished_requests():
+            # Drain async postprocessor (if exists)
+            if len(ctx.output_queue) > 0:
+                assert not self.scheduler_config.is_multi_step
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            assert len(ctx.output_queue) == 0
+
+        return ctx.request_outputs

     async def stop_remote_worker_execution_loop_async(self) -> None:
         """Stop the remote worker execution loop."""
@@ -1,7 +1,8 @@
+import functools
 import time
 from collections import deque
 from contextlib import contextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
                     Mapping, Optional)
 from typing import Sequence as GenericSequence
@@ -88,6 +89,17 @@ class SchedulerOutputState:
     last_output: Optional[SamplerOutput] = None


+@dataclass
+class SchedulerContext:
+    output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
+                              SchedulerOutputs]] = field(
+                                  default_factory=lambda: deque())
+
+    request_outputs: List[Union[RequestOutput,
+                                EmbeddingRequestOutput]] = field(
+                                    default_factory=lambda: [])
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.

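
A self-contained rendering of the SchedulerContext dataclass added above, with the vLLM element types replaced by Any so the snippet runs on its own. The point of the structure is that each virtual engine (pipeline-parallel stage) carries its own queue of not-yet-processed step outputs and its own list of request outputs.

# Stand-alone sketch; SamplerOutput, SequenceGroupMetadata, SchedulerOutputs
# and RequestOutput are typed as Any purely so this imports without vLLM.
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Tuple


@dataclass
class SchedulerContext:
    # One (sampler outputs, seq group metadata, scheduler outputs) entry per
    # step whose postprocessing has not run yet.
    output_queue: Deque[Tuple[List[Any], List[Any], Any]] = field(
        default_factory=deque)
    # RequestOutputs accumulated for the step currently being returned.
    request_outputs: List[Any] = field(default_factory=list)


# One context per virtual engine, so the postprocessing state of one
# pipeline-parallel microbatch never leaks into another.
pipeline_parallel_size = 2  # example value
scheduler_contexts = [SchedulerContext() for _ in range(pipeline_parallel_size)]
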
@@ -350,9 +362,11 @@ class LLMEngine:
             Scheduler(
                 scheduler_config, cache_config, lora_config,
                 parallel_config.pipeline_parallel_size,
-                self._process_model_outputs
+                functools.partial(self._process_model_outputs,
+                                  virtual_engine=v_id,
+                                  is_async=True)
                 if model_config.use_async_output_proc else None)
-            for _ in range(parallel_config.pipeline_parallel_size)
+            for v_id in range(parallel_config.pipeline_parallel_size)
         ]

         # Metric Logging.
@@ -406,12 +420,17 @@ class LLMEngine:
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]

-        # Async output processing pointers
-        self.output_queue: Deque[Tuple[List[SamplerOutput],
-                                       List[SequenceGroupMetadata],
-                                       SchedulerOutputs]] = deque()
-        self.request_outputs: List[Union[RequestOutput,
-                                         EmbeddingRequestOutput]] = []
+        self.scheduler_contexts = [
+            SchedulerContext()
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
+        self.async_callback = [
+            functools.partial(self._process_model_outputs,
+                              virtual_engine=v_id,
+                              is_async=True)
+            for v_id in range(self.parallel_config.pipeline_parallel_size)
+        ]

     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
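
The list comprehension above pre-binds virtual_engine and is_async with functools.partial, so everything downstream (ExecuteModelRequest, the worker, the model runner) can invoke the postprocessor as a zero-argument callable. A toy stand-in, assuming nothing beyond what the hunk shows:

# ToyEngine is a stand-in for LLMEngine; the real _process_model_outputs does
# far more than print, this only demonstrates the partial binding.
import functools


class ToyEngine:

    def __init__(self, pipeline_parallel_size: int) -> None:
        self.async_callback = [
            functools.partial(self._process_model_outputs,
                              virtual_engine=v_id,
                              is_async=True)
            for v_id in range(pipeline_parallel_size)
        ]

    def _process_model_outputs(self, virtual_engine: int,
                               is_async: bool) -> None:
        print(f"draining outputs of virtual engine {virtual_engine} "
              f"(is_async={is_async})")


engine = ToyEngine(pipeline_parallel_size=2)
engine.async_callback[1]()  # zero-argument call, the engine id is already bound
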
@@ -1221,32 +1240,28 @@ class LLMEngine:

         return

-    def _process_model_outputs(self,
-                               is_async: bool,
-                               clear_outputs: bool = True) -> None:
+    def _process_model_outputs(self, virtual_engine: int,
+                               is_async: bool) -> None:
         """Apply the model output to the sequences in the scheduled seq groups.

+        virtual_engine: The engine id to operate on
         is_async: Indicates whether this postprocessor runs in
             parallel with the GPU forward pass and is processing
             tokens from the previous step. If this is true, then
             no tokens need to be appended since it is already done
             externally (before the next schedule() call)
-        clear_outputs: Sometimes existing outputs need to be combined
-            with outputs of this call. This happens for postprocessor
-            draining at the final stage (like when sequences are finished)

         Returns RequestOutputs that can be returned to the client.
         """
         now = time.time()

-        if clear_outputs:
-            self.request_outputs.clear()
+        ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]

-        if len(self.output_queue) == 0:
+        if len(ctx.output_queue) == 0:
             return None

         (outputs, seq_group_metadata_list,
-         scheduler_outputs) = self.output_queue.popleft()
+         scheduler_outputs) = ctx.output_queue.popleft()

         # Sanity check
         assert len(seq_group_metadata_list) == len(
@@ -1321,11 +1336,11 @@ class LLMEngine:
             if (seq_group.is_finished()
                     if self.step_return_finished_only else True):
                 request_output = RequestOutputFactory.create(seq_group)
-                self.request_outputs.append(request_output)
+                ctx.request_outputs.append(request_output)

         for seq_group in scheduler_outputs.ignored_seq_groups:
             request_output = RequestOutputFactory.create(seq_group)
-            self.request_outputs.append(request_output)
+            ctx.request_outputs.append(request_output)

         if is_async:
             # Log stats.
@@ -1421,29 +1436,43 @@ class LLMEngine:
                 "Pipeline parallelism is only supported through AsyncLLMEngine "
                 "as performance will be severely degraded otherwise.")

+        # For llm_engine, there is no pipeline parallel support, so the engine
+        # used is always 0
+        virtual_engine = 0
+
         # These are cached outputs from previous iterations. None if on first
         # iteration
-        cached_outputs = self.cached_scheduler_outputs[0]
+        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
         seq_group_metadata_list = cached_outputs.seq_group_metadata_list
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

+        ctx = self.scheduler_contexts[virtual_engine]
+
         # Skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
-            (seq_group_metadata_list, scheduler_outputs,
-             allow_async_output_proc) = self.scheduler[0].schedule()

-            if not allow_async_output_proc and len(self.output_queue) > 0:
-                self._process_model_outputs(is_async=True)
+            # Clear outputs on scheduler iteration start
+            ctx.request_outputs.clear()
+
+            # Schedule iteration
+            (seq_group_metadata_list, scheduler_outputs,
+             allow_async_output_proc
+             ) = self.scheduler[virtual_engine].schedule()
+
+            # Maybe switch from async mode to sync mode
+            if not allow_async_output_proc and len(ctx.output_queue) > 0:
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)

         if (self.scheduler_config.is_multi_step
                 and scheduler_outputs.num_lookahead_slots > 0):
             # cache the scheduler outputs for the next iteration if we have
             # lookahead slots
             self._cache_scheduler_outputs_for_multi_step(
-                0, seq_group_metadata_list, scheduler_outputs,
+                virtual_engine, seq_group_metadata_list, scheduler_outputs,
                 allow_async_output_proc)

         assert seq_group_metadata_list is not None
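
The same rule shows up in both engines: whenever a step cannot overlap its postprocessing (allow_async_output_proc is False) or no unfinished requests remain, anything still sitting in ctx.output_queue from earlier async steps is flushed before moving on. The helper below is a generalized sketch of that draining rule, not code from the diff; in the diff at most one entry can be pending, so a single call followed by the assert suffices, but the loop makes the invariant explicit.

# Generalized drain helper (illustrative only): flush every pending
# (outputs, seq_group_metadata_list, scheduler_outputs) entry before running a
# synchronous step or returning final results.
from collections import deque
from typing import Any, Callable, Deque, Tuple

PendingStep = Tuple[Any, Any, Any]


def drain_pending(output_queue: Deque[PendingStep],
                  process_one: Callable[[], Any]) -> None:
    # process_one mirrors a bound _process_model_outputs(...) call and is
    # expected to pop exactly one entry from the queue per invocation.
    while len(output_queue) > 0:
        process_one()
    assert len(output_queue) == 0


# Tiny usage example with a stand-in processing function.
queue: Deque[PendingStep] = deque([("outputs", "metadata", "sched_outputs")])
drain_pending(queue, process_one=lambda: queue.popleft())
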
@@ -1454,14 +1483,14 @@ class LLMEngine:

         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
-                0].get_and_reset_finished_requests_ids()
+                virtual_engine].get_and_reset_finished_requests_ids()

             # Check if we have a cached last_output from the previous iteration.
             # For supporting PP this is probably the best way to pass the
             # sampled_token_ids, as a separate broadcast over all the PP stages
             # will cause one virtual engine's microbatch to block the pipeline.
             last_sampled_token_ids = \
-                self._get_last_sampled_token_ids(0)
+                self._get_last_sampled_token_ids(virtual_engine)

             execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
@@ -1476,20 +1505,24 @@ class LLMEngine:
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                execute_model_req.output_proc_callback_fn = \
-                    self._process_model_outputs
+                execute_model_req.async_callback = self.async_callback[
+                    virtual_engine]

             output = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)

-            # we need to do this here so that last step's sampled_token_ids can
+            # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
-                self._update_cached_scheduler_output(0, output)
+                self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(self.output_queue) > 0:
+            # Nothing scheduled => If there is pending async postprocessor,
+            # then finish it here.
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            # No outputs in this case
             output = []

         # Finish the current step for all the sequence groups.
@@ -1504,7 +1537,7 @@ class LLMEngine:

             # Add results to the output_queue
             # (for async or non-async postprocessing)
-            self.output_queue.append(
+            ctx.output_queue.append(
                 (output, seq_group_metadata_list, scheduler_outputs))

             if output and allow_async_output_proc:
@@ -1515,8 +1548,10 @@ class LLMEngine:
                     output[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)

+            # Check if need to run the usual non-async path
             if not allow_async_output_proc:
-                self._process_model_outputs(is_async=False)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=False)

             # Log stats.
             self.do_log_stats(scheduler_outputs, output)
@@ -1524,14 +1559,16 @@ class LLMEngine:
             # Tracing
             self.do_tracing(scheduler_outputs)
         else:
-            self.request_outputs = []
+            # Multi-step case
+            ctx.request_outputs = []

         if not self.has_unfinished_requests():
-            # Drain async postprocessor
-            if len(self.output_queue) > 0:
+            # Drain async postprocessor (if exists)
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True, clear_outputs=False)
-            assert len(self.output_queue) == 0
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            assert len(ctx.output_queue) == 0

             # Stop the execute model loop in parallel workers until there are
             # more requests to process. This avoids waiting indefinitely in
@@ -1540,7 +1577,7 @@ class LLMEngine:
             # queued control plane messages, such as add/remove lora adapters.
             self.model_executor.stop_remote_worker_execution_loop()

-        return self.request_outputs
+        return ctx.request_outputs

     def _has_remaining_steps(
             self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
@@ -811,6 +811,9 @@ class SequenceGroup:
         self.is_single_seq = len(self.seqs) == 1

     def is_finished(self) -> bool:
+        if self.is_single_seq:
+            return self.seqs[0].is_finished()
+
         return all(seq.is_finished() for seq in self.seqs)

     def is_prefill(self) -> bool:
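
A minimal stand-alone illustration of the fast path added above: when the group holds a single sequence, is_finished() skips building a generator over all sequences. Seq and SeqGroup below are simplified stand-ins for vLLM's Sequence and SequenceGroup.

# Simplified stand-ins; only the is_finished() logic mirrors the hunk above.
from dataclasses import dataclass
from typing import List


@dataclass
class Seq:
    finished: bool = False

    def is_finished(self) -> bool:
        return self.finished


class SeqGroup:

    def __init__(self, seqs: List[Seq]) -> None:
        self.seqs = seqs
        self.is_single_seq = len(seqs) == 1

    def is_finished(self) -> bool:
        if self.is_single_seq:
            return self.seqs[0].is_finished()
        return all(seq.is_finished() for seq in self.seqs)


assert not SeqGroup([Seq()]).is_finished()
assert SeqGroup([Seq(finished=True), Seq(finished=True)]).is_finished()
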
@@ -1290,8 +1293,8 @@ class ExecuteModelRequest(
     finished_requests_ids: List[str] = msgspec.field(default_factory=list)
     # The last sampled token ids for multi step decoding.
     last_sampled_token_ids: Optional[torch.Tensor] = None
-    # Async postprocessor
-    output_proc_callback_fn: Optional[Callable] = None
+    # Async callback
+    async_callback: Optional[Callable] = None

     @property
     def is_first_multi_step(self) -> bool:
@@ -1338,4 +1341,4 @@ class ExecuteModelRequest(
             finished_requests_ids=self.finished_requests_ids,
             last_sampled_token_ids=self.last_sampled_token_ids.clone()
             if self.last_sampled_token_ids is not None else None,
-            output_proc_callback_fn=self.output_proc_callback_fn)
+            async_callback=self.async_callback)
@@ -91,7 +91,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
     request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
     finished_requests_ids: Optional[List[str]] = None
     virtual_engine: int = 0
-    output_proc_callback_fn: Optional[Callable] = None
+    async_callback: Optional[Callable] = None

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -1457,8 +1457,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         if not self.is_driver_worker:
             return []

-        if model_input.output_proc_callback_fn is not None:
-            model_input.output_proc_callback_fn(is_async=True)
+        if model_input.async_callback is not None:
+            model_input.async_callback()

         # Sample the next token.
         output: SamplerOutput = self.model.sample(
@@ -263,11 +263,10 @@ class LocalOrDistributedWorkerBase(WorkerBase):
             broadcast_data.update(kwargs)
             broadcast_tensor_dict(broadcast_data, src=0)

-        if execute_model_req.output_proc_callback_fn:
+        if execute_model_req.async_callback:
             model_input = dataclasses.replace(  # type: ignore
                 model_input,
-                output_proc_callback_fn=execute_model_req.
-                output_proc_callback_fn)
+                async_callback=execute_model_req.async_callback)

         return model_input, worker_input, kwargs

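
Taken together, the last three hunks show the hand-off path of the callback: the engine attaches a zero-argument async_callback to ExecuteModelRequest, the worker copies it onto the model input with dataclasses.replace, and the model runner fires it right before sampling, so CPU postprocessing of the previous step overlaps the current step's GPU work. A compressed, self-contained toy of that path, with all class names being simplified stand-ins for the vLLM types:

# Toy end-to-end sketch of the callback hand-off; ToyModelInput and
# ToyExecuteModelRequest are stand-ins, not vLLM classes.
import dataclasses
from typing import Callable, Optional


@dataclasses.dataclass(frozen=True)
class ToyModelInput:
    virtual_engine: int = 0
    async_callback: Optional[Callable] = None


@dataclasses.dataclass
class ToyExecuteModelRequest:
    async_callback: Optional[Callable] = None


def prepare_input(req: ToyExecuteModelRequest,
                  model_input: ToyModelInput) -> ToyModelInput:
    if req.async_callback:
        # dataclasses.replace returns a copy with the callback attached,
        # mirroring the worker_base hunk above.
        model_input = dataclasses.replace(model_input,
                                          async_callback=req.async_callback)
    return model_input


def run_model(model_input: ToyModelInput) -> None:
    # ... forward pass would run here ...
    if model_input.async_callback is not None:
        # Process the previous step's outputs on the CPU while this step is
        # still in flight, as in the model_runner hunk above.
        model_input.async_callback()
    # ... sampling would run here ...


req = ToyExecuteModelRequest(async_callback=lambda: print("drained step N-1"))
run_model(prepare_input(req, ToyModelInput(virtual_engine=0)))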