From 0ed646b7aa3b434b2fcb6f6b6e725570879cb89e Mon Sep 17 00:00:00 2001 From: Murali Andoorveedu <37849411+andoorve@users.noreply.github.com> Date: Wed, 3 Jul 2024 17:52:29 -0700 Subject: [PATCH] [Distributed][Core] Support Py39 and Py38 for PP (#6120) Signed-off-by: Muralidhar Andoorveedu --- vllm/executor/executor_base.py | 7 +------ vllm/executor/ray_gpu_executor.py | 9 +++++++++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 2abb29c1..fc18dec0 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -123,12 +123,7 @@ class ExecutorAsyncBase(ExecutorBase): multimodal_config: Optional[MultiModalConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: - # This locks each pipeline parallel stage so multiple virtual engines - # can't execute on the same stage at the same time - self.pp_locks = [ - asyncio.Lock() - for _ in range(parallel_config.pipeline_parallel_size) - ] + self.pp_locks: Optional[List[asyncio.Lock]] = None super().__init__(model_config, cache_config, parallel_config, scheduler_config, device_config, load_config, diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index e0b9441a..bc7ef9cc 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -349,6 +349,15 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: + if self.pp_locks is None: + # This locks each pipeline parallel stage so multiple virtual + # engines can't execute on the same stage at the same time + # We create the locks here to avoid creating them in the constructor + # which uses a different asyncio loop. + self.pp_locks = [ + asyncio.Lock() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] async def _run_task_with_lock(task, lock, *args, **kwargs): async with lock: