[Core][Distributed] use cpu/gloo to initialize pynccl (#4248)

youkaichao 2024-04-23 18:32:19 -07:00 committed by GitHub
parent 79a268c4ab
commit 91f50a6fe2
5 changed files with 93 additions and 71 deletions

View File

@@ -5,6 +5,7 @@ import torch
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.utils import update_environment_variables
@@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
for p in processes:
p.join()
for p in processes:
assert p.exitcode == 0
def update_env(fn):
def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapper(env):
def wrapped_fn(env):
update_environment_variables(env)
init_distributed_environment()
fn()
return wrapper
return wrapped_fn
@update_env
@worker_fn_wrapper
def worker_fn():
comm = NCCLCommunicator()
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
@@ -53,7 +58,7 @@ def test_pynccl():
distributed_run(worker_fn, 2)
@update_env
@worker_fn_wrapper
def worker_fn_with_cudagraph():
with torch.no_grad():
graph = torch.cuda.CUDAGraph()
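The decorator rename above also adds a call to init_distributed_environment() inside the wrapper, so every test worker has the default torch.distributed world (and the CPU/gloo group vLLM builds on top of it) before it constructs an NCCLCommunicator. Below is a minimal, CPU-only sketch of the same pattern outside vLLM; the gloo init_process_group stands in for vLLM's init_distributed_environment, the port number is arbitrary, and the Linux fork start method used by the original test is assumed.

import multiprocessing
import os

import torch
import torch.distributed as dist


def worker_fn_wrapper(fn):
    # `multiprocessing.Process` cannot accept environment variables directly,
    # so the parent passes them as an argument and the child applies them
    # before doing any distributed setup.
    def wrapped_fn(env):
        os.environ.update(env)
        # stand-in for vLLM's init_distributed_environment(): a plain
        # CPU/gloo process group read from the standard env:// variables
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    # a trivial CPU collective just to prove the group is usable
    t = torch.ones(4)
    dist.all_reduce(t)
    assert torch.all(t == dist.get_world_size())


if __name__ == "__main__":
    world_size = 2
    processes = []
    for rank in range(world_size):
        env = {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "29503",
        }
        p = multiprocessing.Process(target=worker_fn, args=(env, ))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
        assert p.exitcode == 0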

View File

@@ -20,14 +20,15 @@
# variable in the code.
import ctypes
import datetime
import platform
from typing import Optional, Union
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check
@@ -59,6 +60,18 @@ except Exception as e:
ncclResult_t = ctypes.c_int
_c_ncclGetErrorString = nccl.ncclGetErrorString
_c_ncclGetErrorString.restype = ctypes.c_char_p
_c_ncclGetErrorString.argtypes = [ncclResult_t]
def NCCL_CHECK(result: ncclResult_t) -> None:
if result != 0:
error_str = _c_ncclGetErrorString(result)
error_str = error_str.decode("utf-8")
raise RuntimeError(f"NCCL error: {error_str}")
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
@@ -68,8 +81,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
def ncclGetVersion() -> str:
version = ctypes.c_int()
result = _c_ncclGetVersion(ctypes.byref(version))
assert result == 0
NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
# something like 21903 --> "2.19.3"
version_str = str(version.value)
major = version_str[0].lstrip("0")
@@ -91,8 +103,7 @@ _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
def ncclGetUniqueId() -> NcclUniqueId:
unique_id = NcclUniqueId()
result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
assert result == 0
NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
return unique_id
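The three hunks above swap bare `assert result == 0` checks for a NCCL_CHECK helper that converts a nonzero ncclResult_t into a RuntimeError carrying ncclGetErrorString's message. A standalone sketch of the same ctypes pattern follows; it assumes libnccl.so.2 is on the loader path (vLLM locates the library via find_nccl_library instead) and only exercises ncclGetVersion.

import ctypes

# assumption: the NCCL shared library is loadable by this exact soname
nccl = ctypes.CDLL("libnccl.so.2")

ncclResult_t = ctypes.c_int

# const char* ncclGetErrorString(ncclResult_t result);
_c_ncclGetErrorString = nccl.ncclGetErrorString
_c_ncclGetErrorString.restype = ctypes.c_char_p
_c_ncclGetErrorString.argtypes = [ncclResult_t]


def NCCL_CHECK(result: int) -> None:
    # ncclSuccess is 0; every other code is an error
    if result != 0:
        error_str = _c_ncclGetErrorString(result).decode("utf-8")
        raise RuntimeError(f"NCCL error: {error_str}")


# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
_c_ncclGetVersion.restype = ncclResult_t
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]

version = ctypes.c_int()
NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
print(version.value)  # e.g. 21903, which vLLM formats as "2.19.3"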
@@ -199,66 +210,75 @@ class NCCLCommunicator:
def __init__(
self,
backend=None,
init_method=None,
timeout=datetime.timedelta(seconds=10),
world_size: int = -1,
rank: int = -1,
store=None,
group_name: str = "",
pg_options=None,
local_rank: int = -1,
group: Optional[ProcessGroup] = None,
device: Optional[Union[int, str, torch.device]] = None,
):
if not dist.is_initialized():
backend = backend or "nccl"
assert backend == 'nccl', (
"only use nccl backend for starting the NCCL communicator")
dist.init_process_group(backend=backend,
init_method=init_method,
timeout=timeout,
world_size=world_size,
rank=rank,
store=store,
group_name=group_name,
pg_options=pg_options)
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
if local_rank == -1:
local_rank = self.rank
self.local_rank = local_rank
# don't use these args, as they can be -1
# use `self.rank`, `self.local_rank` and `self.world_size` instead
del world_size, rank, local_rank
torch.cuda.set_device(self.local_rank)
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the NCCLCommunicator to. If None,
it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bound to a unique device.
"""
assert dist.is_initialized()
group = get_cpu_world_group() if group is None else group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"NCCLCommunicator should be attached to a non-NCCL group.")
self.group = group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
if self.rank == 0:
self.unique_id = ncclGetUniqueId()
else:
self.unique_id = NcclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
self.local_rank)
dist.broadcast(tensor, src=0)
byte_list = tensor.cpu().tolist()
tensor = torch.ByteTensor(list(self.unique_id.internal))
dist.broadcast(tensor, src=0, group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p()
result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank)
assert result == 0
self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
if device is None:
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
current_device = torch.cuda.current_device()
try:
torch.cuda.set_device(device)
NCCL_CHECK(
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank))
self.stream = torch.cuda.Stream()
finally:
torch.cuda.set_device(current_device)
def all_reduce(self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream))
assert result == 0
NCCL_CHECK(
_c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream)))
def __del__(self):
# `dist` module might have been already destroyed
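The rewritten __init__ above is the core of this PR: instead of creating its own torch.distributed world, NCCLCommunicator now attaches to an existing non-NCCL (CPU/gloo) group, broadcasts the NCCL unique id over that group as a plain CPU byte tensor, and only touches the GPU when it temporarily binds the target device to create the NCCL communicator and its stream. The sketch below reproduces that handshake in a runnable, CPU-only form: os.urandom stands in for ncclGetUniqueId, the final print stands in for ncclCommInitRank, and the port number is arbitrary.

import multiprocessing
import os

import torch
import torch.distributed as dist

NCCL_UNIQUE_ID_BYTES = 128  # sizeof(ncclUniqueId) in NCCL's C API


def init_pynccl_like(rank: int, world_size: int):
    os.environ.update({
        "RANK": str(rank),
        "WORLD_SIZE": str(world_size),
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "29504",
    })
    # the CPU/gloo group that vLLM keeps alongside its NCCL world group
    dist.init_process_group(backend="gloo")
    group = dist.group.WORLD

    # rank 0 generates the id (stand-in for ncclGetUniqueId());
    # the other ranks start from zeros and receive it via broadcast
    if dist.get_rank(group) == 0:
        unique_id = bytearray(os.urandom(NCCL_UNIQUE_ID_BYTES))
    else:
        unique_id = bytearray(NCCL_UNIQUE_ID_BYTES)

    # the broadcast travels over gloo on CPU, so no GPU memory and no
    # pre-existing NCCL communicator are needed during bootstrap
    tensor = torch.ByteTensor(list(unique_id))
    dist.broadcast(tensor, src=0, group=group)
    unique_id = bytes(tensor.tolist())

    # at this point every rank would bind its CUDA device and call
    # ncclCommInitRank(world_size, unique_id, rank); printed here instead
    print(f"rank {dist.get_rank(group)} got id {unique_id[:4].hex()}...")


if __name__ == "__main__":
    world_size = 2
    ps = [
        multiprocessing.Process(target=init_pynccl_like, args=(i, world_size))
        for i in range(world_size)
    ]
    for p in ps:
        p.start()
    for p in ps:
        p.join()
        assert p.exitcode == 0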

View File

@@ -2,7 +2,7 @@ import contextlib
from typing import Optional
import torch
from torch.distributed import ReduceOp
from torch.distributed import ProcessGroup, ReduceOp
from vllm.logger import init_logger
@@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
pass
def init_process_group(world_size: int,
rank: int,
init_method: str,
local_rank: int = -1) -> None:
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
assert not is_initialized()
global comm
logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
comm = NCCLCommunicator(init_method=init_method,
world_size=world_size,
local_rank=local_rank,
rank=rank)
comm = NCCLCommunicator(group=group)
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:

View File

@@ -4,6 +4,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib
import os
from typing import Optional
import torch
@@ -73,6 +74,11 @@ def init_distributed_environment(
ranks = list(range(torch.distributed.get_world_size()))
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
backend="gloo")
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
if local_rank == -1 and distributed_init_method == "env://":
local_rank = int(os.environ['LOCAL_RANK'])
global _LOCAL_RANK
_LOCAL_RANK = local_rank
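A torch ProcessGroup has no notion of a local (per-node) rank, which is why parallel_state now records it separately and, for the env:// init method, falls back to the LOCAL_RANK variable exported by launchers such as torchrun. A small sketch of that fallback; the helper name below is illustrative, not vLLM's.

import os

import torch


def resolve_local_rank(local_rank: int, distributed_init_method: str) -> int:
    # hypothetical helper mirroring the fallback added in parallel_state:
    # trust an explicit local_rank, otherwise read the launcher's env var
    if local_rank == -1 and distributed_init_method == "env://":
        local_rank = int(os.environ["LOCAL_RANK"])
    return local_rank


if __name__ == "__main__":
    # e.g. `torchrun --nproc_per_node=4 this_script.py` exports LOCAL_RANK=0..3
    local_rank = resolve_local_rank(-1, "env://")
    # the device a NCCLCommunicator would be bound to by default
    print(torch.device(f"cuda:{local_rank}"))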

View File

@@ -298,12 +298,9 @@ def init_worker_distributed_environment(
elif parallel_config.world_size > 1:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
pynccl_utils.init_process_group(
world_size=parallel_config.world_size,
local_rank=local_rank,
rank=rank,
init_method=distributed_init_method,
)
# NOTE(kaichao): By default, pynccl will use information inside
# `parallel_state` for initialization.
pynccl_utils.init_process_group()
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
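With the bootstrap arguments gone, the only contract left in the worker is ordering: the torch.distributed world and the CPU/gloo group must exist before pynccl_utils.init_process_group() runs, and both must precede the model-parallel group setup. A hedged sketch of that ordering as a single worker process would execute it; module paths are taken from this commit, the env:// defaults are assumed, and a one-GPU-per-process launch via torchrun is assumed.

import torch

from vllm.distributed import parallel_state
from vllm.distributed.device_communicators import pynccl_utils

# 1. default world group, CPU/gloo world group, and local-rank bookkeeping,
#    all read from the env:// variables set by the launcher
parallel_state.init_distributed_environment()

# 2. pynccl piggybacks on the gloo group; no world_size/rank/init_method here
if torch.distributed.get_world_size() > 1:
    pynccl_utils.init_process_group()

# 3. tensor/pipeline-parallel subgroups on top of the world group
parallel_state.ensure_model_parallel_initialized(1, 1)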