Don't use cupy NCCL for AMD backends (#2855)
parent 4efbac6d35
commit 25e86b6a61
@@ -67,6 +67,10 @@ def get_handle() -> Optional["CustomAllreduce"]:
     return _CA_HANDLE


+def is_initialized() -> bool:
+    return _CA_HANDLE is not None
+
+
 @contextmanager
 def capture():
     try:
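Note: the added is_initialized() mirrors the existing get_handle(): both read the module-level _CA_HANDLE singleton that the custom all-reduce module creates at most once per process. A minimal self-contained sketch of that pattern, with a hypothetical _set_handle() standing in for the module's real initialization path:

from typing import Optional

# Module-level singleton holding the custom all-reduce handle (None until set up).
_CA_HANDLE: Optional[object] = None


def _set_handle(handle: object) -> None:
    # Hypothetical initializer; the real module creates the handle during
    # distributed setup when the custom kernel is enabled and supported.
    global _CA_HANDLE
    _CA_HANDLE = handle


def get_handle() -> Optional[object]:
    return _CA_HANDLE


def is_initialized() -> bool:
    # True once the custom all-reduce handle exists.
    return _CA_HANDLE is not None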
@@ -1,3 +1,4 @@
+import contextlib
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union

@@ -9,9 +10,9 @@ from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
-from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend
 from vllm.model_executor.parallel_utils.parallel_state import (
     with_cupy_nccl_for_all_reduce)
 from vllm.model_executor.parallel_utils import custom_all_reduce
@@ -659,7 +660,7 @@ class ModelRunner:
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
         # deleted before the CUDA graphs.
-        self.cupy_nccl_backend = get_nccl_backend()
+        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()

         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
@@ -689,8 +690,6 @@ class ModelRunner:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]

-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
         # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
         # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
         # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
@@ -698,6 +697,8 @@ class ModelRunner:
         # We always prioritize using custom all-reduce kernel but fall back
         # to PyTorch or CuPy NCCL if it is disabled or not supported.
         with custom_all_reduce.capture():
+            # NOTE: Capturing the largest batch size first may help reduce the
+            # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
                 # Create dummy input_metadata.
                 input_metadata = InputMetadata(
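Note: the NOTE above pins down the priority among the three all-reduce backends: the custom all-reduce kernel first, CuPy NCCL as the fallback inside CUDA graph capture, and PyTorch NCCL otherwise. A hypothetical helper that states the same decision in code (the function name and string labels are illustrative, not part of this diff):

def _graph_allreduce_backend(custom_ar_ready: bool, cupy_ready: bool,
                             capturing_graph: bool) -> str:
    # Highest priority: the custom all-reduce kernel, whenever it is available.
    if custom_ar_ready:
        return "custom-all-reduce"
    # Prefer CuPy NCCL under graph capture when it was set up
    # (after this change it is never set up on ROCm).
    if capturing_graph and cupy_ready:
        return "cupy-nccl"
    # Otherwise fall back to PyTorch NCCL.
    return "pytorch-nccl"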
@@ -765,7 +766,7 @@ class CUDAGraphRunner:
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with with_cupy_nccl_for_all_reduce():
+        with _maybe_cupy_nccl():
             self.model(
                 input_ids,
                 positions,
@@ -779,7 +780,7 @@ class CUDAGraphRunner:
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
-            with with_cupy_nccl_for_all_reduce():
+            with _maybe_cupy_nccl():
                 hidden_states = self.model(
                     input_ids,
                     positions,
@@ -830,6 +831,15 @@ class CUDAGraphRunner:
         return self.forward(*args, **kwargs)


+@contextlib.contextmanager
+def _maybe_cupy_nccl():
+    if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
+        with with_cupy_nccl_for_all_reduce():
+            yield
+    else:
+        yield
+
+
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return x + [pad] * (max_len - len(x))
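Note: _maybe_cupy_nccl() is what lets the two call sites above work on both CUDA and ROCm: it routes all-reduces through CuPy NCCL only when CuPy was initialized and the custom all-reduce kernel was not, and otherwise it is a no-op wrapper. Those call sites follow the usual warm-up-then-capture flow; a condensed, hypothetical sketch of that flow (model, input_ids, positions, and memory_pool stand in for the runner's real state):

import torch

def warmup_and_capture(model, input_ids, positions, memory_pool):
    # 1) Run once eagerly so one-time kernel launches (e.g., Triton autotune)
    #    are not recorded into the graph.
    with _maybe_cupy_nccl():
        model(input_ids, positions)
    torch.cuda.synchronize()

    # 2) Capture the same call into a CUDA graph; all-reduces issued during
    #    capture go through whichever graph-safe backend _maybe_cupy_nccl picks.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=memory_pool):
        with _maybe_cupy_nccl():
            hidden_states = model(input_ids, positions)
    return graph, hidden_states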
@@ -19,6 +19,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
 from vllm.lora.request import LoRARequest
+from vllm.utils import is_hip


 class Worker:
@@ -268,7 +269,8 @@ def init_distributed_environment(
             "cupy.distributed is already initialized but the cupy world "
             "size does not match parallel_config.world_size "
             f"({cupy_world_size} vs. {parallel_config.world_size}).")
-    elif parallel_config.world_size > 1 and cupy_port is not None:
+    elif (parallel_config.world_size > 1 and cupy_port is not None
+          and not is_hip()):
         # NOTE(woosuk): We don't initialize CuPy process group when world size
         # is 1.
         # TODO(woosuk): Support multi-node connection.
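Note: the extra "and not is_hip()" condition is the core of the change: on ROCm builds the CuPy NCCL process group is never created, so AMD backends fall back to the custom all-reduce kernel or PyTorch NCCL per the priority note earlier in the diff. A minimal sketch of the kind of check is_hip() performs, assuming it keys off torch.version.hip (the actual helper in vllm.utils may differ):

import torch

def is_hip() -> bool:
    # ROCm builds of PyTorch report a HIP version string; CUDA builds report None.
    return torch.version.hip is not None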