[Core] Pipe worker_class_fn argument in Executor (#7707)
This commit is contained in:
parent
9e51b6a626
commit
66a9e713a7
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -7,15 +7,18 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
|||||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
|
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
|
||||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||||
make_async)
|
make_async)
|
||||||
from vllm.worker.worker_base import WorkerWrapperBase
|
from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_worker(worker_module_name, worker_class_name, **kwargs):
|
def create_worker(worker_module_name: str, worker_class_name: str,
|
||||||
|
worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
|
||||||
|
**kwargs):
|
||||||
wrapper = WorkerWrapperBase(
|
wrapper = WorkerWrapperBase(
|
||||||
worker_module_name=worker_module_name,
|
worker_module_name=worker_module_name,
|
||||||
worker_class_name=worker_class_name,
|
worker_class_name=worker_class_name,
|
||||||
|
worker_class_fn=worker_class_fn,
|
||||||
)
|
)
|
||||||
wrapper.init_worker(**kwargs)
|
wrapper.init_worker(**kwargs)
|
||||||
return wrapper.worker
|
return wrapper.worker
|
||||||
@ -62,7 +65,9 @@ class GPUExecutor(ExecutorBase):
|
|||||||
observability_config=self.observability_config,
|
observability_config=self.observability_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_worker_module_and_class(self) -> Tuple[str, str]:
|
def _get_worker_module_and_class(
|
||||||
|
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
|
||||||
|
worker_class_fn = None
|
||||||
if self.scheduler_config.is_multi_step:
|
if self.scheduler_config.is_multi_step:
|
||||||
worker_module_name = "vllm.worker.multi_step_worker"
|
worker_module_name = "vllm.worker.multi_step_worker"
|
||||||
worker_class_name = "MultiStepWorker"
|
worker_class_name = "MultiStepWorker"
|
||||||
@ -72,7 +77,7 @@ class GPUExecutor(ExecutorBase):
|
|||||||
else:
|
else:
|
||||||
worker_module_name = "vllm.worker.worker"
|
worker_module_name = "vllm.worker.worker"
|
||||||
worker_class_name = "Worker"
|
worker_class_name = "Worker"
|
||||||
return (worker_module_name, worker_class_name)
|
return (worker_module_name, worker_class_name, worker_class_fn)
|
||||||
|
|
||||||
def _get_create_worker_kwargs(
|
def _get_create_worker_kwargs(
|
||||||
self,
|
self,
|
||||||
@ -82,10 +87,13 @@ class GPUExecutor(ExecutorBase):
|
|||||||
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
|
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
|
||||||
distributed_init_method)
|
distributed_init_method)
|
||||||
|
|
||||||
(worker_module_name,
|
(worker_module_name, worker_class_name,
|
||||||
worker_class_name) = self._get_worker_module_and_class()
|
worker_class_fn) = self._get_worker_module_and_class()
|
||||||
worker_kwargs.update(worker_module_name=worker_module_name,
|
worker_kwargs.update(
|
||||||
worker_class_name=worker_class_name)
|
worker_module_name=worker_module_name,
|
||||||
|
worker_class_name=worker_class_name,
|
||||||
|
worker_class_fn=worker_class_fn,
|
||||||
|
)
|
||||||
|
|
||||||
return worker_kwargs
|
return worker_kwargs
|
||||||
|
|
||||||
|
|||||||
@ -91,12 +91,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
|||||||
return ray_remote_kwargs
|
return ray_remote_kwargs
|
||||||
|
|
||||||
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
|
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
|
||||||
(worker_module_name,
|
(worker_module_name, worker_class_name,
|
||||||
worker_class_name) = self._get_worker_module_and_class()
|
worker_class_fn) = self._get_worker_module_and_class()
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
worker_module_name=worker_module_name,
|
worker_module_name=worker_module_name,
|
||||||
worker_class_name=worker_class_name,
|
worker_class_name=worker_class_name,
|
||||||
|
worker_class_fn=worker_class_fn,
|
||||||
trust_remote_code=self.model_config.trust_remote_code,
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -11,6 +11,7 @@ from vllm.executor.gpu_executor import GPUExecutor
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
|
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
|
||||||
from vllm.utils import make_async
|
from vllm.utils import make_async
|
||||||
|
from vllm.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -52,14 +53,16 @@ class XPUExecutor(GPUExecutor):
|
|||||||
# Instantiate the worker and load the model to GPU.
|
# Instantiate the worker and load the model to GPU.
|
||||||
self._init_executor()
|
self._init_executor()
|
||||||
|
|
||||||
def _get_worker_module_and_class(self) -> Tuple[str, str]:
|
def _get_worker_module_and_class(
|
||||||
|
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
|
||||||
|
worker_class_fn = None
|
||||||
if self.speculative_config is not None:
|
if self.speculative_config is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"XPU does not support speculative decoding")
|
"XPU does not support speculative decoding")
|
||||||
else:
|
else:
|
||||||
worker_module_name = "vllm.worker.xpu_worker"
|
worker_module_name = "vllm.worker.xpu_worker"
|
||||||
worker_class_name = "XPUWorker"
|
worker_class_name = "XPUWorker"
|
||||||
return (worker_module_name, worker_class_name)
|
return (worker_module_name, worker_class_name, worker_class_fn)
|
||||||
|
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self, execute_model_req: ExecuteModelRequest
|
self, execute_model_req: ExecuteModelRequest
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user