From 0229c386c541a293f18a9ffe1a5cd7735d487158 Mon Sep 17 00:00:00 2001 From: FlorianJoncour <148003496+FlorianJoncour@users.noreply.github.com> Date: Wed, 29 Nov 2023 21:25:43 +0000 Subject: [PATCH] Better integration with Ray Serve (#1821) Co-authored-by: FlorianJoncour --- vllm/engine/llm_engine.py | 6 +++--- vllm/engine/ray_utils.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f79b5e84..db1ee606 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -7,7 +7,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs -from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray +from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams @@ -162,12 +162,12 @@ class LLMEngine: continue worker = ray.remote( num_cpus=0, - num_gpus=1, + num_gpus=self.cache_config.gpu_memory_utilization, scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=placement_group, placement_group_capture_child_tasks=True), **ray_remote_kwargs, - )(RayWorker).remote(self.model_config.trust_remote_code) + )(RayWorkerVllm).remote(self.model_config.trust_remote_code) self.workers.append(worker) # Initialize torch distributed process group for the workers. diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index ee58b8b9..04660926 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -10,7 +10,7 @@ try: import ray from ray.air.util.torch_dist import TorchDistributedWorker - class RayWorker(TorchDistributedWorker): + class RayWorkerVllm(TorchDistributedWorker): """Ray wrapper for vllm.worker.Worker, allowing Worker to be lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" @@ -36,7 +36,7 @@ except ImportError as e: "`pip install ray pandas pyarrow`.") ray = None TorchDistributedWorker = None - RayWorker = None + RayWorkerVllm = None if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup