[Core] Some simplification of WorkerWrapper changes (#4183)

This commit is contained in:
Nick Hill 2024-04-23 00:49:08 -07:00 committed by GitHub
parent 0ae11f78ab
commit 8f2ea22bde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 54 deletions

View File

@ -2,6 +2,7 @@ import asyncio
import os import os
import pickle import pickle
from collections import defaultdict from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from vllm.engine.ray_utils import RayWorkerWrapper, ray from vllm.engine.ray_utils import RayWorkerWrapper, ray
@ -136,16 +137,14 @@ class RayGPUExecutor(ExecutorBase):
VLLM_INSTANCE_ID = get_vllm_instance_id() VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers. # Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [] all_args_to_update_environment_variables = [({
for (node_id, _) in worker_node_and_gpu_ids:
all_args_to_update_environment_variables.append([{
"CUDA_VISIBLE_DEVICES": "CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id])), ",".join(map(str, node_gpus[node_id])),
"VLLM_INSTANCE_ID": "VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID, VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION": "VLLM_TRACE_FUNCTION":
os.getenv("VLLM_TRACE_FUNCTION", "0"), os.getenv("VLLM_TRACE_FUNCTION", "0"),
}]) }, ) for (node_id, _) in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables", self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables) all_args=all_args_to_update_environment_variables)
@ -156,10 +155,9 @@ class RayGPUExecutor(ExecutorBase):
# avoid writing `{"name": value}` manually # avoid writing `{"name": value}` manually
return kwargs return kwargs
init_worker_all_kwargs = []
# Initialize the actual workers inside worker wrapper. # Initialize the actual workers inside worker wrapper.
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ): init_worker_all_kwargs = []
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
local_rank = node_workers[node_id].index(rank) local_rank = node_workers[node_id].index(rank)
init_worker_all_kwargs.append( init_worker_all_kwargs.append(
collect_arg_helper_func( collect_arg_helper_func(
@ -265,40 +263,40 @@ class RayGPUExecutor(ExecutorBase):
self, self,
method: str, method: str,
*args, *args,
driver_args: Optional[Tuple[Any]] = None, driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None,
all_args: Optional[List[List[Any]]] = None, all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False, use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False, use_ray_compiled_dag: bool = False,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers. """Runs the given method on all workers. Can be used in the following
all_args and all_kwargs are used to pass heterogeneous arguments, ways:
i.e. different arguments for each worker.
- args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
""" """
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# for mypy type checking
assert driver_args is not None
assert driver_kwargs is not None
if all_args is None:
all_args = [driver_args] + [args] * len(self.workers)
if all_kwargs is None:
all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)
# for mypy type checking
assert all_args is not None
assert all_kwargs is not None
if max_concurrent_workers: if max_concurrent_workers:
raise NotImplementedError( raise NotImplementedError(
"max_concurrent_workers is not supported yet.") "max_concurrent_workers is not supported yet.")
if driver_args is None:
driver_args = args if all_args is None else all_args[0]
if driver_kwargs is None:
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
if use_ray_compiled_dag: if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single # Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it. # input. TODO(sang): Fix it.
@ -310,22 +308,17 @@ class RayGPUExecutor(ExecutorBase):
worker.execute_method.remote(method, *worker_args, worker.execute_method.remote(method, *worker_args,
**worker_kwargs) **worker_kwargs)
for (worker, worker_args, worker_kwargs for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_args[1:], all_kwargs[1:]) ) in zip(self.workers, all_worker_args, all_worker_kwargs)
] ]
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# Start the driver worker after all the ray workers. # Start the driver worker after all the ray workers.
if not use_dummy_driver: if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method( driver_worker_output = self.driver_worker.execute_method(
method, *all_args[0], **all_kwargs[0]) method, *driver_args, **driver_kwargs)
else: else:
driver_worker_output = ray.get( driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote( self.driver_dummy_worker.execute_method.remote(
method, *all_args[0], **all_kwargs[0])) method, *driver_args, **driver_kwargs))
# Get the results of the ray workers. # Get the results of the ray workers.
if self.workers: if self.workers:
if use_ray_compiled_dag: if use_ray_compiled_dag:
@ -383,6 +376,10 @@ class RayGPUExecutor(ExecutorBase):
class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_executor = make_async(self.driver_worker.execute_method)
async def _run_workers_async( async def _run_workers_async(
self, self,
method: str, method: str,
@ -399,13 +396,8 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
if driver_kwargs is None: if driver_kwargs is None:
driver_kwargs = kwargs driver_kwargs = kwargs
# Run the driver worker asynchronously. coros.append(
def helper(): self.driver_executor(method, *driver_args, **driver_kwargs))
return self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
driver_executor = make_async(helper)
coros.append(driver_executor())
# Run the ray workers asynchronously. # Run the ray workers asynchronously.
for worker in self.workers: for worker in self.workers:

View File

@ -108,7 +108,8 @@ class WorkerWrapperBase:
self.worker_class_name = worker_class_name self.worker_class_name = worker_class_name
self.worker = None self.worker = None
def update_environment_variables(self, envs: Dict[str, str]) -> None: @staticmethod
def update_environment_variables(envs: Dict[str, str]) -> None:
key = 'CUDA_VISIBLE_DEVICES' key = 'CUDA_VISIBLE_DEVICES'
if key in envs and key in os.environ: if key in envs and key in os.environ:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior # overwriting CUDA_VISIBLE_DEVICES is desired behavior
@ -138,10 +139,8 @@ class WorkerWrapperBase:
def execute_method(self, method, *args, **kwargs): def execute_method(self, method, *args, **kwargs):
try: try:
if hasattr(self, method): target = self if self.worker is None else self.worker
executor = getattr(self, method) executor = getattr(target, method)
else:
executor = getattr(self.worker, method)
return executor(*args, **kwargs) return executor(*args, **kwargs)
except Exception as e: except Exception as e:
# if the driver worker also execute methods, # if the driver worker also execute methods,