[Core] Pipe worker_class_fn argument in Executor (#7707)

This commit is contained in:
Antoni Baum 2024-08-20 17:37:39 -07:00 committed by GitHub
parent 9e51b6a626
commit 66a9e713a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 14 deletions

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
@@ -7,15 +7,18 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
logger = init_logger(__name__)
def create_worker(worker_module_name, worker_class_name, **kwargs):
def create_worker(worker_module_name: str, worker_class_name: str,
worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
**kwargs):
wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
)
wrapper.init_worker(**kwargs)
return wrapper.worker
@@ -62,7 +65,9 @@ class GPUExecutor(ExecutorBase):
observability_config=self.observability_config,
)
def _get_worker_module_and_class(self) -> Tuple[str, str]:
def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_class_fn = None
if self.scheduler_config.is_multi_step:
worker_module_name = "vllm.worker.multi_step_worker"
worker_class_name = "MultiStepWorker"
@@ -72,7 +77,7 @@ class GPUExecutor(ExecutorBase):
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
return (worker_module_name, worker_class_name)
return (worker_module_name, worker_class_name, worker_class_fn)
def _get_create_worker_kwargs(
self,
@@ -82,10 +87,13 @@
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
distributed_init_method)
(worker_module_name,
worker_class_name) = self._get_worker_module_and_class()
worker_kwargs.update(worker_module_name=worker_module_name,
worker_class_name=worker_class_name)
(worker_module_name, worker_class_name,
worker_class_fn) = self._get_worker_module_and_class()
worker_kwargs.update(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
)
return worker_kwargs

View File

@@ -91,12 +91,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
return ray_remote_kwargs
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
(worker_module_name,
worker_class_name) = self._get_worker_module_and_class()
(worker_module_name, worker_class_name,
worker_class_fn) = self._get_worker_module_and_class()
return dict(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
trust_remote_code=self.model_config.trust_remote_code,
)

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Type, Union
import torch
@@ -11,6 +11,7 @@ from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase
logger = init_logger(__name__)
@@ -52,14 +53,16 @@ class XPUExecutor(GPUExecutor):
# Instantiate the worker and load the model to GPU.
self._init_executor()
def _get_worker_module_and_class(self) -> Tuple[str, str]:
def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_class_fn = None
if self.speculative_config is not None:
raise NotImplementedError(
"XPU does not support speculative decoding")
else:
worker_module_name = "vllm.worker.xpu_worker"
worker_class_name = "XPUWorker"
return (worker_module_name, worker_class_name)
return (worker_module_name, worker_class_name, worker_class_fn)
def execute_model(
self, execute_model_req: ExecuteModelRequest