From ce741ba3e4fea00bacd2e1c609ca587ec35eb161 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Sun, 3 Sep 2023 21:43:43 -0700 Subject: [PATCH] Refactor AsyncLLMEngine (#880) --- vllm/core/scheduler.py | 13 +- vllm/engine/async_llm_engine.py | 329 +++++++++++++++++++++----------- vllm/engine/llm_engine.py | 78 ++++---- 3 files changed, 271 insertions(+), 149 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index db063d27..fc2335cc 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,6 +1,6 @@ import enum import time -from typing import Dict, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, Union from vllm.config import CacheConfig, SchedulerConfig from vllm.core.block_manager import BlockSpaceManager @@ -87,17 +87,22 @@ class Scheduler: # Add sequence groups to the waiting queue. self.waiting.append(seq_group) - def abort_seq_group(self, request_id: str) -> None: + def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) for state_queue in [self.waiting, self.running, self.swapped]: for seq_group in state_queue: - if seq_group.request_id == request_id: + if seq_group.request_id in request_ids: # Remove the sequence group from the state queue. state_queue.remove(seq_group) for seq in seq_group.seqs: if seq.is_finished(): continue self.free_seq(seq, SequenceStatus.FINISHED_ABORTED) - return + request_ids.remove(seq_group.request_id) + if not request_ids: + return def has_unfinished_seqs(self) -> bool: return self.waiting or self.running or self.swapped diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 9049505d..54f38676 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,6 +1,7 @@ import asyncio import time -from typing import Dict, List, Optional +from functools import partial +from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs @@ -12,7 +13,105 @@ from vllm.sampling_params import SamplingParams logger = init_logger(__name__) -TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds + +class AsyncStream: + """A stream of RequestOutputs for a request that can be + iterated over asynchronously.""" + + def __init__(self, request_id: str) -> None: + self.request_id = request_id + self._queue = asyncio.Queue() + self._finished = False + + def put(self, item: RequestOutput) -> None: + if self._finished: + return + self._queue.put_nowait(item) + + def finish(self) -> None: + self._queue.put_nowait(StopIteration) + self._finished = True + + @property + def finished(self) -> bool: + return self._finished + + def __aiter__(self): + return self + + async def __anext__(self) -> RequestOutput: + result = await self._queue.get() + if result is StopIteration: + raise StopAsyncIteration + return result + + +def _raise_exception_on_finish(task: asyncio.Task) -> None: + try: + task.result() + except Exception as e: + raise RuntimeError("Task finished unexpectedly.") from e + raise RuntimeError("Task finished unexpectedly.") + + +class _AsyncLLMEngine(LLMEngine): + """Extension of LLMEngine to add async methods.""" + + async def step_async(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + The workers are ran asynchronously if possible. + + This function performs one decoding iteration of the engine. 
It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + (seq_group_metadata_list, scheduler_outputs, + early_return) = self._schedule() + if early_return is not None: + return early_return + + # Execute the model. + output = await self._run_workers_async( + "execute_model", + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + ) + + return self._process_worker_outputs(output, scheduler_outputs) + + async def _run_workers_async( + self, + method: str, + *args, + get_all_outputs: bool = False, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + all_outputs = [] + for worker in self.workers: + if self.parallel_config.worker_use_ray: + executor = partial(worker.execute_method.remote, method) + else: + executor = getattr(worker, method) + + output = executor(*args, **kwargs) + all_outputs.append(output) + + if self.parallel_config.worker_use_ray: + all_outputs = await asyncio.gather(*all_outputs) + + if get_all_outputs: + return all_outputs + + # Make sure all workers have the same results. + output = all_outputs[0] + for other_output in all_outputs[1:]: + assert output == other_output + return output class AsyncLLMEngine: @@ -37,49 +136,111 @@ class AsyncLLMEngine: *args, *kwargs: Arguments for LLMEngine. """ + _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine + def __init__(self, worker_use_ray: bool, engine_use_ray: bool, *args, log_requests: bool = True, + start_engine_loop: bool = False, **kwargs) -> None: self.worker_use_ray = worker_use_ray self.engine_use_ray = engine_use_ray self.log_requests = log_requests - if not self.engine_use_ray: - engine_class = LLMEngine - elif self.worker_use_ray: - engine_class = ray.remote(num_cpus=0)(LLMEngine).remote - else: - engine_class = ray.remote(num_gpus=1)(LLMEngine).remote - self.engine = engine_class(*args, **kwargs) - # Request id -> request output. - self.request_outputs: Dict[str, RequestOutput] = {} - # Request id -> event to notify that there is new output. - self.request_events: Dict[str, asyncio.Event] = {} - self.is_engine_running = False - self.kicking_request_id: Optional[str] = None + self.engine = self._init_engine(*args, **kwargs) - async def engine_step(self, kicking_request_id: Optional[str] = None): + # Request id -> stream. 
+ self.request_streams: Dict[str, AsyncStream] = {} + self.finished_requests: Set[str] = set() + self.background_loop = None + if start_engine_loop: + self._start_background_loop() + + def _start_background_loop(self) -> None: + """Start the background loop.""" + if self.background_loop is not None: + raise RuntimeError("Background loop is already running.") + self.background_loop = asyncio.get_event_loop().create_task( + self.run_engine_loop()) + self.background_loop.add_done_callback(_raise_exception_on_finish) + + def _init_engine(self, *args, + **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]: + if not self.engine_use_ray: + engine_class = self._engine_class + elif self.worker_use_ray: + engine_class = ray.remote(num_cpus=0)(self._engine_class).remote + else: + engine_class = ray.remote(num_gpus=1)(self._engine_class).remote + return engine_class(*args, **kwargs) + + async def engine_step(self): """Kick the engine to process the waiting requests.""" - self.is_engine_running = True - self.kicking_request_id = kicking_request_id if self.engine_use_ray: request_outputs = await self.engine.step.remote() else: - # Yield to the event loop to allow other coroutines to run - # while is_engine_running is True. This let the engine to add new - # requests into the queue. - await asyncio.sleep(0) - request_outputs = self.engine.step() - self.is_engine_running = False - self.kicking_request_id = None + request_outputs = await self.engine.step_async() - # Notify the waiting coroutines that there are new outputs ready. + # Put the outputs into the corresponding streams. for request_output in request_outputs: request_id = request_output.request_id - self.request_outputs[request_id] = request_output - self.request_events[request_id].set() + self.request_streams[request_id].put(request_output) + if request_output.finished: + if self.log_requests: + logger.info(f"Finished request {request_id}.") + self.request_streams[request_id].finish() + self.finished_requests.add(request_id) + + await self._engine_abort(self.finished_requests) + for request_id in self.finished_requests: + del self.request_streams[request_id] + self.finished_requests.clear() + + async def _engine_abort(self, request_ids: Iterable[str]): + if self.engine_use_ray: + await self.engine.abort_request.remote(request_ids) + else: + self.engine.abort_request(request_ids) + + async def run_engine_loop(self): + while True: + await self.engine_step() + await asyncio.sleep(0) + + async def add_request( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + ) -> AsyncStream: + if self.log_requests: + logger.info(f"Received request {request_id}: " + f"prompt: {prompt!r}, " + f"sampling params: {sampling_params}, " + f"prompt token ids: {prompt_token_ids}.") + + stream = AsyncStream(request_id) + self.request_streams[request_id] = stream + + # Add the request into the vLLM engine's waiting queue. + if self.engine_use_ray: + await self.engine.add_request.remote( + request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time) + else: + self.engine.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time) + + return stream async def generate( self, @@ -108,76 +269,19 @@ class AsyncLLMEngine: # Preprocess the request. 
arrival_time = time.time() - # Create an event to notify us that there is new output from the - # vLLM engine. - request_event = asyncio.Event() - self.request_events[request_id] = request_event + try: + stream = await self.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time) - if self.log_requests: - logger.info(f"Received request {request_id}: " - f"prompt: {prompt!r}, " - f"sampling params: {sampling_params}, " - f"prompt token ids: {prompt_token_ids}.") - - # Add the request into the vLLM engine's waiting queue. - if self.engine_use_ray: - await self.engine.add_request.remote( - request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) - else: - self.engine.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) - - # The vLLM engine does not have a background loop that keeps - # processing incoming requests. Therefore, we need to keep kicking - # the engine to process the requests. - while True: - if request_id not in self.request_events: - # The request has been aborted. - return - - # Kick the engine if the engine is not running. - if not self.is_engine_running: - try: - await self.engine_step(request_id) - except RuntimeError as e: - await self.abort(request_id) - raise e - - # Wait for new output. The group_event will be set in engine_step - # when there is new output available for the sequence group. - # Added a timeout to prevent deadlock. - try: - await asyncio.wait_for(request_event.wait(), - timeout=TIMEOUT_TO_PREVENT_DEADLOCK) - except asyncio.TimeoutError: - continue - # Reset the event to wait for the next output. - request_event.clear() - - # Decode and return new outputs. - request_output = self.request_outputs[request_id] - yield request_output - - # Once finished, release the resources of the sequence group. - if request_output.finished: - if self.log_requests: - logger.info(f"Finished request {request_id}.") - - del self.request_outputs[request_id] - del self.request_events[request_id] - # Kick the engine if the engine is not running. This is to - # prevent that there are still requests in engine's waiting - # queue to be executed. - if not self.is_engine_running: - await self.engine_step() - break + async for request_output in stream: + yield request_output + except Exception as e: + # If there is an exception, abort the request. + self._abort(request_id) + raise e async def abort(self, request_id: str) -> None: """Abort a request. @@ -188,28 +292,27 @@ class AsyncLLMEngine: Args: request_id: The unique id of the request. """ - if request_id not in self.request_events: + return self._abort(request_id) + + def _abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + if request_id not in self.request_streams or self.request_streams[ + request_id].finished: # The request has already finished or been aborted. 
return if self.log_requests: logger.info(f"Aborted request {request_id}.") - if self.engine_use_ray: - await self.engine.abort_request.remote(request_id) - else: - self.engine.abort_request(request_id) - - if request_id in self.request_events: - del self.request_events[request_id] - if request_id in self.request_outputs: - del self.request_outputs[request_id] - - # To prevent deadlock when a request is aborted while the engine is - # running. - if self.kicking_request_id == request_id: - self.is_engine_running = False - self.kicking_request_id = None + self.request_streams[request_id].finish() + self.finished_requests.add(request_id) async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 908d01d9..54141bbe 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,17 +1,18 @@ -import time import copy +import time from functools import partial -from typing import Any, List, Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.core.scheduler import Scheduler +from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs -from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker +from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata, + SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, get_tokenizer) from vllm.utils import Counter @@ -135,7 +136,8 @@ class LLMEngine: get_all_outputs=True, ) - def _init_workers_ray(self, placement_group: "PlacementGroup"): + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel @@ -150,6 +152,7 @@ class LLMEngine: scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=placement_group, placement_group_capture_child_tasks=True), + **ray_remote_kwargs, )(RayWorker).remote() self.workers.append(worker) @@ -268,11 +271,11 @@ class LLMEngine: # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) - def abort_request(self, request_id: str) -> None: - """Aborts a request with the given ID. + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a request(s) with the given ID. Args: - request_id: The ID of the request to abort. + request_id: The ID(s) of the request to abort. """ self.scheduler.abort_seq_group(request_id) @@ -288,35 +291,21 @@ class LLMEngine: """Returns True if there are unfinished requests.""" return self.scheduler.has_unfinished_seqs() - def step(self) -> List[RequestOutput]: - """Performs one decoding iteration and returns newly generated results. - - This function performs one decoding iteration of the engine. It first - schedules the sequences to be executed in the next iteration and the - token blocks to be swapped in/out/copy. 
Then, it executes the model - and updates the scheduler with the model outputs. Finally, it decodes - the sequences and returns the newly generated results. - """ + def _schedule( + self + ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, + Optional[List[RequestOutput]]]: seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if scheduler_outputs.is_empty(): - if not scheduler_outputs.ignored_seq_groups: - # Nothing to do. - return [] - # If there are ignored seq groups, we need to return them as the - # request outputs. - return [ + return seq_group_metadata_list, scheduler_outputs, [ RequestOutput.from_seq_group(seq_group) for seq_group in scheduler_outputs.ignored_seq_groups ] + return seq_group_metadata_list, scheduler_outputs, None - # Execute the model. - output = self._run_workers( - "execute_model", - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy, - ) + def _process_worker_outputs( + self, output, + scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: # Update the scheduler with the model outputs. seq_groups = self.scheduler.update(output) @@ -339,6 +328,31 @@ class LLMEngine: scheduler_outputs.num_batched_tokens) return request_outputs + def step(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + (seq_group_metadata_list, scheduler_outputs, + early_return) = self._schedule() + if early_return is not None: + return early_return + + # Execute the model. + output = self._run_workers( + "execute_model", + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + ) + + return self._process_worker_outputs(output, scheduler_outputs) + def _log_system_stats( self, prompt_run: bool,
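
Reviewer note, not part of the patch: below is a minimal sketch of how the refactored
streaming API introduced above is meant to be consumed. It assumes an AsyncLLMEngine
already constructed with start_engine_loop=True (constructor arguments beyond those
shown in this diff are omitted), and it assumes the generate(prompt, sampling_params,
request_id) call order and the RequestOutput fields (.outputs[0].text, .finished) of
the surrounding code at this revision. The helper name stream_one_request and the
sampling values are illustrative only.

    import asyncio

    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.sampling_params import SamplingParams


    async def stream_one_request(engine: AsyncLLMEngine, request_id: str,
                                 prompt: str) -> None:
        # generate() registers the request via add_request(), which creates an
        # AsyncStream; the background loop (run_engine_loop -> engine_step)
        # then pushes RequestOutputs into that stream as decoding progresses.
        params = SamplingParams(temperature=0.8, max_tokens=64)
        try:
            async for request_output in engine.generate(
                    prompt, params, request_id):
                # Each RequestOutput carries the text generated so far; the
                # final one has request_output.finished == True, after which
                # engine_step() aborts the request in the engine and drops
                # its stream.
                print(request_output.outputs[0].text)
        except asyncio.CancelledError:
            # If the caller is cancelled (e.g. a client disconnects), abort
            # explicitly so Scheduler.abort_seq_group frees the sequence
            # group instead of leaving it in the waiting/running queues.
            await engine.abort(request_id)
            raise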