From 25e86b6a616638cea9ce121a6c28c7b1d69615e7 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 14 Feb 2024 12:30:44 -0800
Subject: [PATCH] Don't use cupy NCCL for AMD backends (#2855)

---
 .../parallel_utils/custom_all_reduce.py |  4 ++++
 vllm/worker/model_runner.py             | 22 ++++++++++++++-----
 vllm/worker/worker.py                   |  4 +++-
 3 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py
index 628c1517..ce4c8d02 100644
--- a/vllm/model_executor/parallel_utils/custom_all_reduce.py
+++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py
@@ -67,6 +67,10 @@ def get_handle() -> Optional["CustomAllreduce"]:
     return _CA_HANDLE
 
 
+def is_initialized() -> bool:
+    return _CA_HANDLE is not None
+
+
 @contextmanager
 def capture():
     try:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 065d5899..a27b7d9c 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1,3 +1,4 @@
+import contextlib
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union
 
@@ -9,9 +10,9 @@ from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
-from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend
 from vllm.model_executor.parallel_utils.parallel_state import (
     with_cupy_nccl_for_all_reduce)
 from vllm.model_executor.parallel_utils import custom_all_reduce
@@ -659,7 +660,7 @@ class ModelRunner:
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
         # deleted before the CUDA graphs.
-        self.cupy_nccl_backend = get_nccl_backend()
+        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
 
         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
@@ -689,8 +690,6 @@ class ModelRunner:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
         # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
         # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
         # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
@@ -697,7 +696,9 @@ class ModelRunner:
         # graph, we use either custom all-reduce kernel or PyTorch NCCL.
         # We always prioritize using custom all-reduce kernel but fall back
         # to PyTorch or CuPy NCCL if it is disabled or not supported.
         with custom_all_reduce.capture():
+            # NOTE: Capturing the largest batch size first may help reduce the
+            # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
                 # Create dummy input_metadata.
                 input_metadata = InputMetadata(
@@ -765,7 +766,7 @@ class CUDAGraphRunner:
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with with_cupy_nccl_for_all_reduce():
+        with _maybe_cupy_nccl():
             self.model(
                 input_ids,
                 positions,
@@ -779,7 +780,7 @@ class CUDAGraphRunner:
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
-            with with_cupy_nccl_for_all_reduce():
+            with _maybe_cupy_nccl():
                 hidden_states = self.model(
                     input_ids,
                     positions,
@@ -830,6 +831,15 @@ class CUDAGraphRunner:
         return self.forward(*args, **kwargs)
 
 
+@contextlib.contextmanager
+def _maybe_cupy_nccl():
+    if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
+        with with_cupy_nccl_for_all_reduce():
+            yield
+    else:
+        yield
+
+
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return x + [pad] * (max_len - len(x))
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index c460e2e0..29e4b16f 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -19,6 +19,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
 from vllm.lora.request import LoRARequest
+from vllm.utils import is_hip
 
 
 class Worker:
@@ -268,7 +269,8 @@ def init_distributed_environment(
             "cupy.distributed is already initialized but the cupy world "
             "size does not match parallel_config.world_size "
            f"({cupy_world_size} vs. {parallel_config.world_size}).")
-    elif parallel_config.world_size > 1 and cupy_port is not None:
+    elif (parallel_config.world_size > 1 and cupy_port is not None
+          and not is_hip()):
        # NOTE(woosuk): We don't initialize CuPy process group when world size
        # is 1.
        # TODO(woosuk): Support multi-node connection.
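
Below is a minimal, standalone sketch of the fallback logic this patch introduces in _maybe_cupy_nccl(). The underscore-prefixed stubs are placeholders standing in for vLLM's cupy_utils.is_initialized(), custom_all_reduce.is_initialized(), and with_cupy_nccl_for_all_reduce(); only the selection pattern itself is taken from the diff above. On AMD (HIP) builds the worker now skips CuPy process-group setup because of the is_hip() guard, so the first check is always False and CUDA graph capture falls through to the non-CuPy path.

import contextlib


def _cupy_nccl_initialized() -> bool:
    # Placeholder for cupy_utils.is_initialized(). On ROCm this stays False
    # because the worker never initializes the CuPy process group.
    return False


def _custom_all_reduce_initialized() -> bool:
    # Placeholder for custom_all_reduce.is_initialized(), the helper added
    # to custom_all_reduce.py in this patch.
    return False


@contextlib.contextmanager
def _with_cupy_nccl():
    # Placeholder for with_cupy_nccl_for_all_reduce(), which would route
    # all-reduce calls through CuPy NCCL inside this region.
    yield


@contextlib.contextmanager
def maybe_cupy_nccl():
    # Mirrors _maybe_cupy_nccl() in model_runner.py: use CuPy NCCL only when
    # it is initialized and the custom all-reduce kernel is not handling the
    # collective; otherwise fall through to the default backend
    # (custom kernel or PyTorch NCCL).
    if _cupy_nccl_initialized() and not _custom_all_reduce_initialized():
        with _with_cupy_nccl():
            yield
    else:
        yield


if __name__ == "__main__":
    with maybe_cupy_nccl():
        # CUDA graph capture/replay would run the model here.
        print("running with the selected all-reduce backend")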