From 258a2c58d08fc7a242556120877a89404861fbce Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 26 Apr 2024 21:14:26 -0700
Subject: [PATCH] [Core] Introduce `DistributedGPUExecutor` abstract class (#4348)

---
 vllm/executor/distributed_gpu_executor.py | 114 ++++++++++++++++++++++
 vllm/executor/ray_gpu_executor.py         |  94 ++----------------
 2 files changed, 122 insertions(+), 86 deletions(-)
 create mode 100644 vllm/executor/distributed_gpu_executor.py

diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py
new file mode 100644
index 00000000..9dccfa49
--- /dev/null
+++ b/vllm/executor/distributed_gpu_executor.py
@@ -0,0 +1,114 @@
+from abc import abstractmethod
+from typing import Any, Dict, Optional, Set, Tuple
+
+from vllm.executor.executor_base import ExecutorAsyncBase
+from vllm.executor.gpu_executor import GPUExecutor
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.sequence import SamplerOutput
+
+logger = init_logger(__name__)
+
+
+class DistributedGPUExecutor(GPUExecutor):
+    """Abstract superclass of multi-GPU executor implementations."""
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Determine the number of available KV blocks.
+
+        This invokes `determine_num_available_blocks` on each worker and takes
+        the min of the results, guaranteeing that the selected cache sizes are
+        compatible with all workers.
+
+        Returns:
+            - tuple[num_gpu_blocks, num_cpu_blocks]
+        """
+        # Get the maximum number of blocks that can be allocated on GPU and CPU.
+        num_blocks = self._run_workers("determine_num_available_blocks", )
+
+        # Since we use a shared centralized controller, we take the minimum
+        # number of blocks across all workers to make sure all the memory
+        # operators can be applied to all workers.
+        num_gpu_blocks = min(b[0] for b in num_blocks)
+        num_cpu_blocks = min(b[1] for b in num_blocks)
+
+        return num_gpu_blocks, num_cpu_blocks
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        """Initialize the KV cache in all workers.
+        """
+
+        # NOTE: We log here to avoid multiple logs when number of workers is
+        # greater than one. We could log in the engine, but not all executors
+        # have GPUs.
+        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
+                    num_cpu_blocks)
+
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+        self._run_workers("initialize_cache",
+                          num_gpu_blocks=num_gpu_blocks,
+                          num_cpu_blocks=num_cpu_blocks)
+
+    def execute_model(self, *args, **kwargs) -> SamplerOutput:
+        all_outputs = self._run_workers("execute_model",
+                                        driver_args=args,
+                                        driver_kwargs=kwargs)
+
+        # Only the driver worker returns the sampling results.
+        return all_outputs[0]
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "add_lora",
+            lora_request=lora_request,
+        )
+
+    def remove_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "remove_lora",
+            lora_id=lora_id,
+        )
+
+    def list_loras(self) -> Set[int]:
+        return self._run_workers("list_loras")
+
+    @abstractmethod
+    def _run_workers(
+        self,
+        method: str,
+        *args,
+        driver_args: Optional[Tuple[Any, ...]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
+        max_concurrent_workers: Optional[int] = None,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+        raise NotImplementedError
+
+
+class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
+
+    @abstractmethod
+    async def _run_workers_async(
+        self,
+        method: str,
+        *args,
+        driver_args: Optional[Tuple[Any, ...]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+        raise NotImplementedError
+
+    async def execute_model_async(self, *args, **kwargs) -> SamplerOutput:
+        all_outputs = await self._run_workers_async("execute_model",
+                                                    driver_args=args,
+                                                    driver_kwargs=kwargs)
+
+        # Only the driver worker returns the sampling results.
+        return all_outputs[0]
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 6f72babe..10829848 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -3,12 +3,12 @@ import os
 import pickle
 from collections import defaultdict
 from itertools import islice, repeat
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
-from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from vllm.executor.distributed_gpu_executor import (  # yapf: disable
+    DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from vllm.executor.ray_utils import RayWorkerWrapper, ray
 from vllm.logger import init_logger
-from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         get_vllm_instance_id, make_async)
@@ -27,7 +27,7 @@ logger = init_logger(__name__)
 USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
 
 
-class RayGPUExecutor(ExecutorBase):
+class RayGPUExecutor(DistributedGPUExecutor):
 
     def _init_executor(self) -> None:
         assert (not self.speculative_config
@@ -179,50 +179,9 @@ class RayGPUExecutor(ExecutorBase):
         self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
 
         self._run_workers("init_device")
-        self._run_workers(
-            "load_model",
-            max_concurrent_workers=self.parallel_config.
-            max_parallel_loading_workers,
-        )
-
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
-        """Determine the number of available KV blocks.
-
-        This invokes `determine_num_available_blocks` on each worker and takes
-        the min of the results, guaranteeing that the selected cache sizes are
-        compatible with all workers.
-
-        Returns:
-            - Tuple[num_gpu_blocks, num_cpu_blocks]
-        """
-        # Get the maximum number of blocks that can be allocated on GPU and CPU.
-        num_blocks = self._run_workers("determine_num_available_blocks", )
-
-        # Since we use a shared centralized controller, we take the minimum
-        # number of blocks across all workers to make sure all the memory
-        # operators can be applied to all workers.
-        num_gpu_blocks = min(b[0] for b in num_blocks)
-        num_cpu_blocks = min(b[1] for b in num_blocks)
-
-        return num_gpu_blocks, num_cpu_blocks
-
-    def initialize_cache(self, num_gpu_blocks: int,
-                         num_cpu_blocks: int) -> None:
-        """Initialize the KV cache in all workers.
-        """
-
-        # NOTE: We log here to avoid multiple logs when number of workers is
-        # greater than one. We could log in the engine, but not all executors
-        # have GPUs.
-        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
-                    num_cpu_blocks)
-
-        self.cache_config.num_gpu_blocks = num_gpu_blocks
-        self.cache_config.num_cpu_blocks = num_cpu_blocks
-
-        self._run_workers("initialize_cache",
-                          num_gpu_blocks=num_gpu_blocks,
-                          num_cpu_blocks=num_cpu_blocks)
+        self._run_workers("load_model",
+                          max_concurrent_workers=self.parallel_config.
+                          max_parallel_loading_workers)
 
     def execute_model(self,
                       seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -244,23 +203,6 @@ class RayGPUExecutor(ExecutorBase):
         output = all_outputs[0]
         return output
 
-    def add_lora(self, lora_request: LoRARequest) -> bool:
-        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
-        return self._run_workers(
-            "add_lora",
-            lora_request=lora_request,
-        )
-
-    def remove_lora(self, lora_id: int) -> bool:
-        assert lora_id > 0, "lora_id must be greater than 0."
-        return self._run_workers(
-            "remove_lora",
-            lora_id=lora_id,
-        )
-
-    def list_loras(self) -> Set[int]:
-        return self._run_workers("list_loras")
-
     def _run_workers(
         self,
         method: str,
@@ -378,7 +320,7 @@ class RayGPUExecutor(ExecutorBase):
                                f"Dead Workers: {dead_actors}. ")
 
 
-class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
+class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -409,23 +351,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
 
         all_outputs = await asyncio.gather(*coros)
         return all_outputs
-
-    async def execute_model_async(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-    ) -> SamplerOutput:
-        all_outputs = await self._run_workers_async(
-            "execute_model",
-            driver_kwargs={
-                "seq_group_metadata_list": seq_group_metadata_list,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-            })
-
-        # Only the driver worker returns the sampling results.
-        output = all_outputs[0]
-        return output
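
Note (illustration, not part of the patch): a concrete subclass of the new abstract base
only has to supply the `_run_workers` hook; everything else (KV-block sizing, cache
initialization, LoRA management, execute_model) is inherited. The sketch below shows the
expected shape of that hook. The `LocalGPUExecutor` name and the `self.driver_worker` /
`self.workers` attributes (remote handles assumed to expose
`execute_method(name, *args, **kwargs)`) are hypothetical placeholders; `RayGPUExecutor`
above remains the real implementation, which dispatches to Ray actors instead.

    from typing import Any, Dict, Optional, Tuple

    from vllm.executor.distributed_gpu_executor import DistributedGPUExecutor


    class LocalGPUExecutor(DistributedGPUExecutor):
        """Hypothetical executor: one in-process driver worker plus worker
        handles that expose execute_method(name, *args, **kwargs)."""

        def _run_workers(
            self,
            method: str,
            *args,
            driver_args: Optional[Tuple[Any, ...]] = None,
            driver_kwargs: Optional[Dict[str, Any]] = None,
            max_concurrent_workers: Optional[int] = None,
            **kwargs,
        ) -> Any:
            if max_concurrent_workers:
                raise NotImplementedError(
                    "max_concurrent_workers is not supported in this sketch.")

            # Dispatch the call to every non-driver worker.
            worker_outputs = [
                worker.execute_method(method, *args, **kwargs)
                for worker in self.workers
            ]

            # The driver worker may receive different arguments, e.g. when
            # execute_model() passes driver_args / driver_kwargs.
            if driver_args is None:
                driver_args = args
            if driver_kwargs is None:
                driver_kwargs = kwargs
            driver_output = getattr(self.driver_worker, method)(*driver_args,
                                                                 **driver_kwargs)

            # Driver output goes first, matching the all_outputs[0] convention
            # that execute_model() in the abstract base class relies on.
            return [driver_output] + worker_outputs

This mirrors the driver-versus-remote split that RayGPUExecutor._run_workers already
performs, which is what lets the shared logic move up into DistributedGPUExecutor.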