Don't use cupy NCCL for AMD backends (#2855)

Woosuk Kwon 2024-02-14 12:30:44 -08:00 committed by GitHub
parent 4efbac6d35
commit 25e86b6a61
3 changed files with 23 additions and 7 deletions

vllm/model_executor/parallel_utils/custom_all_reduce.py View File

@@ -67,6 +67,10 @@ def get_handle() -> Optional["CustomAllreduce"]:
return _CA_HANDLE
def is_initialized() -> bool:
return _CA_HANDLE is not None
@contextmanager
def capture():
try:
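For context, this hunk extends the module-level handle API of the custom all-reduce module. Below is a minimal, self-contained sketch of that pattern, for illustration only: the _IS_CAPTURING flag and the object stored in _CA_HANDLE are assumed stand-ins, not the actual vLLM implementation. The idea is a single global handle, a cheap is_initialized() probe, and a capture() context manager that flips a flag for the duration of CUDA-graph capture.

    # Illustrative sketch (not the vLLM implementation) of the module-level
    # handle pattern shown in the hunk above.
    from contextlib import contextmanager
    from typing import Optional

    _CA_HANDLE: Optional[object] = None   # set once by an init routine elsewhere
    _IS_CAPTURING = False                  # assumed flag, for illustration

    def get_handle() -> Optional[object]:
        return _CA_HANDLE

    def is_initialized() -> bool:
        return _CA_HANDLE is not None

    @contextmanager
    def capture():
        # Mark the capture window so all-reduce calls can detect graph capture.
        global _IS_CAPTURING
        try:
            _IS_CAPTURING = True
            yield
        finally:
            _IS_CAPTURING = False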

vllm/worker/model_runner.py View File

@@ -1,3 +1,4 @@
import contextlib
import time
from typing import Dict, List, Optional, Tuple, Set, Union
@@ -9,9 +10,9 @@ from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend
from vllm.model_executor.parallel_utils.parallel_state import (
with_cupy_nccl_for_all_reduce)
from vllm.model_executor.parallel_utils import custom_all_reduce
@@ -659,7 +660,7 @@ class ModelRunner:
def capture_model(self, kv_caches: List[KVCache]) -> None:
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
# deleted before the CUDA graphs.
self.cupy_nccl_backend = get_nccl_backend()
self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
assert not self.model_config.enforce_eager
logger.info("Capturing the model for CUDA graphs. This may lead to "
@@ -689,8 +690,6 @@
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
# NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
# kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
# either custom all-reduce kernel or CuPy NCCL. When not using CUDA
@@ -698,6 +697,8 @@
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or CuPy NCCL if it is disabled or not supported.
with custom_all_reduce.capture():
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
# Create dummy input_metadata.
input_metadata = InputMetadata(
@@ -765,7 +766,7 @@ class CUDAGraphRunner:
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
with with_cupy_nccl_for_all_reduce():
with _maybe_cupy_nccl():
self.model(
input_ids,
positions,
@@ -779,7 +780,7 @@
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
with with_cupy_nccl_for_all_reduce():
with _maybe_cupy_nccl():
hidden_states = self.model(
input_ids,
positions,
@@ -830,6 +831,15 @@ class CUDAGraphRunner:
return self.forward(*args, **kwargs)
@contextlib.contextmanager
def _maybe_cupy_nccl():
if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
with with_cupy_nccl_for_all_reduce():
yield
else:
yield
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
return x + [pad] * (max_len - len(x))
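The NOTE in capture_model above lists three all-reduce backends and their priority, and _maybe_cupy_nccl encodes the graph-capture half of that rule. A hedged, standalone restatement of the selection logic, with the boolean arguments standing in for custom_all_reduce.is_initialized(), cupy_utils.is_initialized(), and the graph-capture state:

    # Illustrative restatement of the backend-priority rule, not vLLM code.
    def pick_all_reduce_backend(custom_ar_ready: bool, cupy_nccl_ready: bool,
                                capturing_cuda_graph: bool) -> str:
        # The custom all-reduce kernel is always preferred when available.
        if custom_ar_ready:
            return "custom-all-reduce"
        # During CUDA-graph capture, CuPy NCCL is used when it was initialized.
        # After this commit it is never initialized on ROCm/HIP, so this branch
        # is never taken there.
        if capturing_cuda_graph and cupy_nccl_ready:
            return "cupy-nccl"
        # Otherwise (eager mode, or CuPy unavailable) PyTorch NCCL is used.
        return "pytorch-nccl"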

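CUDAGraphRunner.capture in the hunks above follows the usual PyTorch capture recipe: run the model once so initial kernel launches (e.g. Triton autotuning) are not recorded, then record the same call inside torch.cuda.graph, optionally sharing a memory pool across graphs. A minimal sketch of that recipe, with model, inputs, and memory_pool as generic placeholders rather than vLLM objects:

    # Sketch of the warm-up-then-capture pattern; assumes a CUDA device.
    import torch

    def capture_graph(model, inputs, memory_pool=None):
        model(*inputs)                     # warm-up run, excluded from the graph
        torch.cuda.synchronize()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=memory_pool):
            outputs = model(*inputs)       # this call is recorded into the graph
        torch.cuda.synchronize()
        return graph, outputs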
vllm/worker/worker.py View File

@@ -19,6 +19,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.lora.request import LoRARequest
from vllm.utils import is_hip
class Worker:
@@ -268,7 +269,8 @@ def init_distributed_environment(
"cupy.distributed is already initialized but the cupy world "
"size does not match parallel_config.world_size "
f"({cupy_world_size} vs. {parallel_config.world_size}).")
elif parallel_config.world_size > 1 and cupy_port is not None:
elif (parallel_config.world_size > 1 and cupy_port is not None
and not is_hip()):
# NOTE(woosuk): We don't initialize CuPy process group when world size
# is 1.
# TODO(woosuk): Support multi-node connection.
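The condition added above reads as a small predicate: set up the CuPy NCCL process group only for multi-GPU runs, and never on ROCm/HIP. A sketch, with world_size, cupy_port, and running_on_hip standing in for parallel_config.world_size, the cupy_port argument, and is_hip():

    # Illustrative predicate mirroring the guard added in this commit.
    from typing import Optional

    def should_init_cupy_nccl(world_size: int, cupy_port: Optional[int],
                              running_on_hip: bool) -> bool:
        # CuPy NCCL is only set up for multi-GPU CUDA runs; single-GPU runs
        # and AMD (ROCm/HIP) builds skip it entirely.
        return world_size > 1 and cupy_port is not None and not running_on_hip

    should_init_cupy_nccl(4, 29500, running_on_hip=True)   # False on AMD backends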