[Core][Distributed] use cpu/gloo to initialize pynccl (#4248)

Author:  youkaichao (committed by GitHub)
Date:    2024-04-23 18:32:19 -07:00
Commit:  91f50a6fe2
Parent:  79a268c4ab
5 changed files with 93 additions and 71 deletions

tests/distributed/test_pynccl.py

@@ -5,6 +5,7 @@ import torch
 from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                            ncclGetUniqueId)
+from vllm.distributed.parallel_state import init_distributed_environment
 from vllm.utils import update_environment_variables
@@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
     for p in processes:
         p.join()
 
+    for p in processes:
+        assert p.exitcode == 0
 
 
-def update_env(fn):
+def worker_fn_wrapper(fn):
     # `multiprocessing.Process` cannot accept environment variables directly
     # so we need to pass the environment variables as arguments
     # and update the environment variables in the function
-    def wrapper(env):
+    def wrapped_fn(env):
         update_environment_variables(env)
+        init_distributed_environment()
         fn()
 
-    return wrapper
+    return wrapped_fn
 
 
-@update_env
+@worker_fn_wrapper
 def worker_fn():
     comm = NCCLCommunicator()
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
@@ -53,7 +58,7 @@ def test_pynccl():
     distributed_run(worker_fn, 2)
 
 
-@update_env
+@worker_fn_wrapper
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
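For reference, the wrapper pattern the test now uses can be exercised on its own. Below is a minimal, self-contained sketch of the same idea under stated assumptions: the environment is passed as an argument, applied in the child process, and torch.distributed is initialized (gloo, CPU-only) before the worker body runs. The names worker_wrapper, demo_worker and launch are illustrative only and not part of vLLM.

    # Sketch of the worker-wrapper pattern used in the test above.
    # `worker_wrapper`, `demo_worker` and `launch` are hypothetical names.
    import multiprocessing
    import os

    import torch
    import torch.distributed as dist


    def worker_wrapper(fn):
        # `multiprocessing.Process` cannot receive environment variables
        # directly, so they travel as an argument and are applied in the child
        def wrapped(env):
            os.environ.update(env)
            dist.init_process_group(backend="gloo", init_method="env://")
            fn()
        return wrapped


    @worker_wrapper
    def demo_worker():
        # a tiny gloo all-reduce just to show the processes are wired up
        t = torch.ones(4)
        dist.all_reduce(t)
        assert torch.all(t == dist.get_world_size())


    def launch(world_size: int = 2):
        # relies on the fork start method (Linux default), since the wrapped
        # worker is a closure and cannot be pickled for spawn
        processes = []
        for rank in range(world_size):
            env = {
                "RANK": str(rank),
                "LOCAL_RANK": str(rank),
                "WORLD_SIZE": str(world_size),
                "MASTER_ADDR": "localhost",
                "MASTER_PORT": "29500",
            }
            p = multiprocessing.Process(target=demo_worker, args=(env, ))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
            assert p.exitcode == 0


    if __name__ == "__main__":
        launch()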

vllm/distributed/device_communicators/pynccl.py

@@ -20,14 +20,15 @@
 # variable in the code.
 import ctypes
-import datetime
 import platform
+from typing import Optional, Union
 
 # ===================== import region =====================
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
 
+from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
 from vllm.logger import init_logger
 from vllm.utils import find_nccl_library, nccl_integrity_check
@@ -59,6 +60,18 @@ except Exception as e:
 ncclResult_t = ctypes.c_int
 
+_c_ncclGetErrorString = nccl.ncclGetErrorString
+_c_ncclGetErrorString.restype = ctypes.c_char_p
+_c_ncclGetErrorString.argtypes = [ncclResult_t]
+
+
+def NCCL_CHECK(result: ncclResult_t) -> None:
+    if result != 0:
+        error_str = _c_ncclGetErrorString(result)
+        error_str = error_str.decode("utf-8")
+        raise RuntimeError(f"NCCL error: {error_str}")
+
+
 # equivalent to c declaration:
 # ncclResult_t ncclGetVersion(int *version);
 _c_ncclGetVersion = nccl.ncclGetVersion
@@ -68,8 +81,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
 
 def ncclGetVersion() -> str:
     version = ctypes.c_int()
-    result = _c_ncclGetVersion(ctypes.byref(version))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
     # something like 21903 --> "2.19.3"
     version_str = str(version.value)
     major = version_str[0].lstrip("0")
@@ -91,8 +103,7 @@ _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
 
 def ncclGetUniqueId() -> NcclUniqueId:
     unique_id = NcclUniqueId()
-    result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
     return unique_id
@@ -199,66 +210,75 @@ class NCCLCommunicator:
 
     def __init__(
         self,
-        backend=None,
-        init_method=None,
-        timeout=datetime.timedelta(seconds=10),
-        world_size: int = -1,
-        rank: int = -1,
-        store=None,
-        group_name: str = "",
-        pg_options=None,
-        local_rank: int = -1,
+        group: Optional[ProcessGroup] = None,
+        device: Optional[Union[int, str, torch.device]] = None,
     ):
-        if not dist.is_initialized():
-            backend = backend or "nccl"
-            assert backend == 'nccl', (
-                "only use nccl backend for starting the NCCL communicator")
-            dist.init_process_group(backend=backend,
-                                    init_method=init_method,
-                                    timeout=timeout,
-                                    world_size=world_size,
-                                    rank=rank,
-                                    store=store,
-                                    group_name=group_name,
-                                    pg_options=pg_options)
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
-        if local_rank == -1:
-            local_rank = self.rank
-        self.local_rank = local_rank
-        # don't use these args, as they can be -1
-        # use `self.rank`, `self.local_rank` and `self.world_size` instead
-        del world_size, rank, local_rank
-        torch.cuda.set_device(self.local_rank)
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the NCCLCommunicator to. If None,
+                it will be bound to f"cuda:{local_rank}".
+        It is the caller's responsibility to make sure each communicator
+        is bound to a unique device.
+        """
+        assert dist.is_initialized()
+        group = get_cpu_world_group() if group is None else group
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "NCCLCommunicator should be attached to a non-NCCL group.")
+        self.group = group
+        self.rank = dist.get_rank(group)
+        self.world_size = dist.get_world_size(group)
         if self.rank == 0:
             self.unique_id = ncclGetUniqueId()
         else:
             self.unique_id = NcclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
-            self.local_rank)
-        dist.broadcast(tensor, src=0)
-        byte_list = tensor.cpu().tolist()
+        tensor = torch.ByteTensor(list(self.unique_id.internal))
+        dist.broadcast(tensor, src=0, group=group)
+        byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
         self.comm = ctypes.c_void_p()
-        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
-                                     self.unique_id, self.rank)
-        assert result == 0
-        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+        if device is None:
+            local_rank = get_local_rank()
+            device = torch.device(f"cuda:{local_rank}")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+        # nccl communicator and stream will use this device
+        current_device = torch.cuda.current_device()
+        try:
+            torch.cuda.set_device(device)
+            NCCL_CHECK(
+                _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                     self.unique_id, self.rank))
+            self.stream = torch.cuda.Stream()
+        finally:
+            torch.cuda.set_device(current_device)
 
     def all_reduce(self,
                    tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None):
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
-                                  ctypes.c_void_p(tensor.data_ptr()),
-                                  tensor.numel(),
-                                  ncclDataTypeEnum.from_torch(tensor.dtype),
-                                  ncclRedOpTypeEnum.from_torch(op), self.comm,
-                                  ctypes.c_void_p(stream.cuda_stream))
-        assert result == 0
+        NCCL_CHECK(
+            _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                              ctypes.c_void_p(tensor.data_ptr()),
                              tensor.numel(),
                              ncclDataTypeEnum.from_torch(tensor.dtype),
                              ncclRedOpTypeEnum.from_torch(op), self.comm,
                              ctypes.c_void_p(stream.cuda_stream)))
 
     def __del__(self):
         # `dist` module might have been already destroyed
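The reworked constructor no longer sets up a process group of its own: it attaches to an existing non-NCCL group (the gloo-backed CPU world group by default), broadcasts the NCCL unique id over that group, and only then creates the NCCL communicator on its bound CUDA device. A short usage sketch follows, assuming two GPUs and a torchrun-style env:// launch; the script itself is illustrative and not part of this commit.

    # Hedged sketch: exercise the reworked NCCLCommunicator on 2 GPUs.
    # Run with something like:  torchrun --nproc_per_node=2 <this script>
    import torch

    from vllm.distributed.device_communicators.pynccl import NCCLCommunicator
    from vllm.distributed.parallel_state import init_distributed_environment


    def main():
        # creates the default process group plus the gloo CPU world group,
        # and records the local rank (read from LOCAL_RANK under env://)
        init_distributed_environment()

        # attaches to the CPU/gloo group by default and binds itself to
        # cuda:<local_rank>; the unique id travels over gloo, so no GPU
        # traffic is needed for initialization
        comm = NCCLCommunicator()

        x = torch.ones(16, 1024, 1024, dtype=torch.float32, device=comm.device)
        comm.all_reduce(x)          # runs on comm.stream by default
        comm.stream.synchronize()
        assert x.mean().item() == comm.world_size


    if __name__ == "__main__":
        main()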

vllm/distributed/device_communicators/pynccl_utils.py

@@ -2,7 +2,7 @@ import contextlib
 from typing import Optional
 
 import torch
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
 
 from vllm.logger import init_logger
@@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
         pass
 
 
-def init_process_group(world_size: int,
-                       rank: int,
-                       init_method: str,
-                       local_rank: int = -1) -> None:
+def init_process_group(group: Optional[ProcessGroup] = None) -> None:
     assert not is_initialized()
     global comm
     logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
-    comm = NCCLCommunicator(init_method=init_method,
-                            world_size=world_size,
-                            local_rank=local_rank,
-                            rank=rank)
+    comm = NCCLCommunicator(group=group)
 
 
 def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
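A brief usage sketch of the simplified entry point, under stated assumptions: the module is imported from vllm.distributed.device_communicators (as the surrounding imports suggest), init_distributed_environment() has already been called via an env:// launch, and the tensor is kept on cuda:<local_rank> to satisfy the new device assertion in all_reduce.

    # Hedged sketch of the simplified pynccl_utils entry point.
    import os

    import torch

    from vllm.distributed.device_communicators import pynccl_utils

    # no arguments: NCCLCommunicator falls back to the gloo CPU world group
    # (an explicit group could be passed instead via `group=...`)
    pynccl_utils.init_process_group()

    # the communicator is bound to cuda:<local_rank>, so keep tensors there
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    t = torch.ones(8, device=f"cuda:{local_rank}")
    pynccl_utils.all_reduce(t)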

vllm/distributed/parallel_state.py

@@ -4,6 +4,7 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
 import contextlib
+import os
 from typing import Optional
 
 import torch
@@ -73,6 +74,11 @@ def init_distributed_environment(
     ranks = list(range(torch.distributed.get_world_size()))
     _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
                                                    backend="gloo")
+    # set the local rank
+    # local_rank is not available in torch ProcessGroup,
+    # see https://github.com/pytorch/pytorch/issues/122816
+    if local_rank == -1 and distributed_init_method == "env://":
+        local_rank = int(os.environ['LOCAL_RANK'])
     global _LOCAL_RANK
     _LOCAL_RANK = local_rank
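The fallback above exists because a torch ProcessGroup does not carry a local rank, so under an env:// launch the value has to come from the LOCAL_RANK variable that launchers like torchrun export. A minimal sketch of the same logic in isolation; the helper name resolve_local_rank is illustrative, not part of vLLM.

    # Hedged sketch of the local-rank fallback added above.
    import os


    def resolve_local_rank(local_rank: int, distributed_init_method: str) -> int:
        # torch.distributed's ProcessGroup does not expose a local rank
        # (https://github.com/pytorch/pytorch/issues/122816), so fall back
        # to the LOCAL_RANK variable exported by the launcher
        if local_rank == -1 and distributed_init_method == "env://":
            local_rank = int(os.environ["LOCAL_RANK"])
        return local_rank


    # e.g. under `torchrun --nproc_per_node=4`, process 2 sees LOCAL_RANK=2,
    # so resolve_local_rank(-1, "env://") returns 2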

vllm/worker/worker.py

@@ -298,12 +298,9 @@ def init_worker_distributed_environment(
     elif parallel_config.world_size > 1:
         # NOTE(woosuk): We don't initialize pynccl process group when world size
         # is 1.
-        pynccl_utils.init_process_group(
-            world_size=parallel_config.world_size,
-            local_rank=local_rank,
-            rank=rank,
-            init_method=distributed_init_method,
-        )
+        # NOTE(kaichao): By default, pynccl will use information inside
+        # `parallel_state` for initialization.
+        pynccl_utils.init_process_group()
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
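Taken together, worker-side initialization after this commit reduces to the sequence sketched below. This is a hedged outline, not the worker's exact code: the wrapper name init_distributed is hypothetical, and the keyword names for init_distributed_environment are assumed to match the parameters referenced in the parallel_state hunk above.

    # Hedged sketch of the initialization order after this commit (env:// launch).
    from vllm.distributed.device_communicators import pynccl_utils
    from vllm.distributed.parallel_state import (
        ensure_model_parallel_initialized, init_distributed_environment)


    def init_distributed(world_size: int, rank: int, init_method: str,
                         local_rank: int, tp_size: int, pp_size: int) -> None:
        # 1. default torch.distributed group + gloo CPU world group + local rank
        init_distributed_environment(world_size=world_size,
                                     rank=rank,
                                     distributed_init_method=init_method,
                                     local_rank=local_rank)
        # 2. pynccl now reads rank / world size / local rank from
        #    parallel_state, so the call site no longer forwards them
        if world_size > 1:
            pynccl_utils.init_process_group()
        # 3. tensor-parallel and pipeline-parallel groups
        ensure_model_parallel_initialized(tp_size, pp_size)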