[Core] Pipe worker_class_fn argument in Executor (#7707)
This commit is contained in:
parent
9e51b6a626
commit
66a9e713a7
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
@ -7,15 +7,18 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def create_worker(worker_module_name, worker_class_name, **kwargs):
|
||||
def create_worker(worker_module_name: str, worker_class_name: str,
|
||||
worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
|
||||
**kwargs):
|
||||
wrapper = WorkerWrapperBase(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
worker_class_fn=worker_class_fn,
|
||||
)
|
||||
wrapper.init_worker(**kwargs)
|
||||
return wrapper.worker
|
||||
@ -62,7 +65,9 @@ class GPUExecutor(ExecutorBase):
|
||||
observability_config=self.observability_config,
|
||||
)
|
||||
|
||||
def _get_worker_module_and_class(self) -> Tuple[str, str]:
|
||||
def _get_worker_module_and_class(
|
||||
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
|
||||
worker_class_fn = None
|
||||
if self.scheduler_config.is_multi_step:
|
||||
worker_module_name = "vllm.worker.multi_step_worker"
|
||||
worker_class_name = "MultiStepWorker"
|
||||
@ -72,7 +77,7 @@ class GPUExecutor(ExecutorBase):
|
||||
else:
|
||||
worker_module_name = "vllm.worker.worker"
|
||||
worker_class_name = "Worker"
|
||||
return (worker_module_name, worker_class_name)
|
||||
return (worker_module_name, worker_class_name, worker_class_fn)
|
||||
|
||||
def _get_create_worker_kwargs(
|
||||
self,
|
||||
@ -82,10 +87,13 @@ class GPUExecutor(ExecutorBase):
|
||||
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
|
||||
distributed_init_method)
|
||||
|
||||
(worker_module_name,
|
||||
worker_class_name) = self._get_worker_module_and_class()
|
||||
worker_kwargs.update(worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name)
|
||||
(worker_module_name, worker_class_name,
|
||||
worker_class_fn) = self._get_worker_module_and_class()
|
||||
worker_kwargs.update(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
worker_class_fn=worker_class_fn,
|
||||
)
|
||||
|
||||
return worker_kwargs
|
||||
|
||||
|
||||
@ -91,12 +91,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
return ray_remote_kwargs
|
||||
|
||||
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
|
||||
(worker_module_name,
|
||||
worker_class_name) = self._get_worker_module_and_class()
|
||||
(worker_module_name, worker_class_name,
|
||||
worker_class_fn) = self._get_worker_module_and_class()
|
||||
|
||||
return dict(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
worker_class_fn=worker_class_fn,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -11,6 +11,7 @@ from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
|
||||
from vllm.utils import make_async
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -52,14 +53,16 @@ class XPUExecutor(GPUExecutor):
|
||||
# Instantiate the worker and load the model to GPU.
|
||||
self._init_executor()
|
||||
|
||||
def _get_worker_module_and_class(self) -> Tuple[str, str]:
|
||||
def _get_worker_module_and_class(
|
||||
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
|
||||
worker_class_fn = None
|
||||
if self.speculative_config is not None:
|
||||
raise NotImplementedError(
|
||||
"XPU does not support speculative decoding")
|
||||
else:
|
||||
worker_module_name = "vllm.worker.xpu_worker"
|
||||
worker_class_name = "XPUWorker"
|
||||
return (worker_module_name, worker_class_name)
|
||||
return (worker_module_name, worker_class_name, worker_class_fn)
|
||||
|
||||
def execute_model(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
|
||||
Loading…
Reference in New Issue
Block a user