From 66a9e713a70d71070c77342b85e020ce536f13c0 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 20 Aug 2024 17:37:39 -0700 Subject: [PATCH] [Core] Pipe `worker_class_fn` argument in Executor (#7707) --- vllm/executor/gpu_executor.py | 26 +++++++++++++++++--------- vllm/executor/ray_gpu_executor.py | 5 +++-- vllm/executor/xpu_executor.py | 9 ++++++--- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index af426e31..8346c3cc 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -7,15 +7,18 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) -from vllm.worker.worker_base import WorkerWrapperBase +from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase logger = init_logger(__name__) -def create_worker(worker_module_name, worker_class_name, **kwargs): +def create_worker(worker_module_name: str, worker_class_name: str, + worker_class_fn: Optional[Callable[[], Type[WorkerBase]]], + **kwargs): wrapper = WorkerWrapperBase( worker_module_name=worker_module_name, worker_class_name=worker_class_name, + worker_class_fn=worker_class_fn, ) wrapper.init_worker(**kwargs) return wrapper.worker @@ -62,7 +65,9 @@ class GPUExecutor(ExecutorBase): observability_config=self.observability_config, ) - def _get_worker_module_and_class(self) -> Tuple[str, str]: + def _get_worker_module_and_class( + self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: + worker_class_fn = None if self.scheduler_config.is_multi_step: worker_module_name = "vllm.worker.multi_step_worker" worker_class_name = "MultiStepWorker" @@ -72,7 +77,7 @@ class GPUExecutor(ExecutorBase): else: worker_module_name = "vllm.worker.worker" worker_class_name = "Worker" - return (worker_module_name, worker_class_name) + return (worker_module_name, worker_class_name, worker_class_fn) def _get_create_worker_kwargs( self, @@ -82,10 +87,13 @@ class GPUExecutor(ExecutorBase): worker_kwargs = self._get_worker_kwargs(local_rank, rank, distributed_init_method) - (worker_module_name, - worker_class_name) = self._get_worker_module_and_class() - worker_kwargs.update(worker_module_name=worker_module_name, - worker_class_name=worker_class_name) + (worker_module_name, worker_class_name, + worker_class_fn) = self._get_worker_module_and_class() + worker_kwargs.update( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + worker_class_fn=worker_class_fn, + ) return worker_kwargs diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index bddb9521..aec6998d 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -91,12 +91,13 @@ class RayGPUExecutor(DistributedGPUExecutor): return ray_remote_kwargs def _get_worker_wrapper_args(self) -> Dict[str, Any]: - (worker_module_name, - worker_class_name) = self._get_worker_module_and_class() + (worker_module_name, worker_class_name, + worker_class_fn) = self._get_worker_module_and_class() return dict( worker_module_name=worker_module_name, worker_class_name=worker_class_name, + worker_class_fn=worker_class_fn, trust_remote_code=self.model_config.trust_remote_code, ) diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 45c8a3db..774204dd 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Type, Union import torch @@ -11,6 +11,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import make_async +from vllm.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -52,14 +53,16 @@ class XPUExecutor(GPUExecutor): # Instantiate the worker and load the model to GPU. self._init_executor() - def _get_worker_module_and_class(self) -> Tuple[str, str]: + def _get_worker_module_and_class( + self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: + worker_class_fn = None if self.speculative_config is not None: raise NotImplementedError( "XPU does not support speculative decoding") else: worker_module_name = "vllm.worker.xpu_worker" worker_class_name = "XPUWorker" - return (worker_module_name, worker_class_name) + return (worker_module_name, worker_class_name, worker_class_fn) def execute_model( self, execute_model_req: ExecuteModelRequest