[Core][Distributed] use cpu/gloo to initialize pynccl (#4248)
This commit is contained in:
parent
79a268c4ab
commit
91f50a6fe2
@ -5,6 +5,7 @@ import torch
|
||||
|
||||
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
|
||||
ncclGetUniqueId)
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
|
||||
@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
for p in processes:
|
||||
assert p.exitcode == 0
|
||||
|
||||
def update_env(fn):
|
||||
|
||||
def worker_fn_wrapper(fn):
|
||||
# `multiprocessing.Process` cannot accept environment variables directly
|
||||
# so we need to pass the environment variables as arguments
|
||||
# and update the environment variables in the function
|
||||
def wrapper(env):
|
||||
def wrapped_fn(env):
|
||||
update_environment_variables(env)
|
||||
init_distributed_environment()
|
||||
fn()
|
||||
|
||||
return wrapper
|
||||
return wrapped_fn
|
||||
|
||||
|
||||
@update_env
|
||||
@worker_fn_wrapper
|
||||
def worker_fn():
|
||||
comm = NCCLCommunicator()
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
|
||||
@ -53,7 +58,7 @@ def test_pynccl():
|
||||
distributed_run(worker_fn, 2)
|
||||
|
||||
|
||||
@update_env
|
||||
@worker_fn_wrapper
|
||||
def worker_fn_with_cudagraph():
|
||||
with torch.no_grad():
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
@ -20,14 +20,15 @@
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import datetime
|
||||
import platform
|
||||
from typing import Optional, Union
|
||||
|
||||
# ===================== import region =====================
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ReduceOp
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_nccl_library, nccl_integrity_check
|
||||
|
||||
@ -59,6 +60,18 @@ except Exception as e:
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
|
||||
_c_ncclGetErrorString = nccl.ncclGetErrorString
|
||||
_c_ncclGetErrorString.restype = ctypes.c_char_p
|
||||
_c_ncclGetErrorString.argtypes = [ncclResult_t]
|
||||
|
||||
|
||||
def NCCL_CHECK(result: ncclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = _c_ncclGetErrorString(result)
|
||||
error_str = error_str.decode("utf-8")
|
||||
raise RuntimeError(f"NCCL error: {error_str}")
|
||||
|
||||
|
||||
# equivalent to c declaration:
|
||||
# ncclResult_t ncclGetVersion(int *version);
|
||||
_c_ncclGetVersion = nccl.ncclGetVersion
|
||||
@ -68,8 +81,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
|
||||
|
||||
def ncclGetVersion() -> str:
|
||||
version = ctypes.c_int()
|
||||
result = _c_ncclGetVersion(ctypes.byref(version))
|
||||
assert result == 0
|
||||
NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
|
||||
# something like 21903 --> "2.19.3"
|
||||
version_str = str(version.value)
|
||||
major = version_str[0].lstrip("0")
|
||||
@ -91,8 +103,7 @@ _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
|
||||
|
||||
def ncclGetUniqueId() -> NcclUniqueId:
|
||||
unique_id = NcclUniqueId()
|
||||
result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
|
||||
assert result == 0
|
||||
NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
|
||||
@ -199,66 +210,75 @@ class NCCLCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend=None,
|
||||
init_method=None,
|
||||
timeout=datetime.timedelta(seconds=10),
|
||||
world_size: int = -1,
|
||||
rank: int = -1,
|
||||
store=None,
|
||||
group_name: str = "",
|
||||
pg_options=None,
|
||||
local_rank: int = -1,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
device: Optional[Union[int, str, torch.device]] = None,
|
||||
):
|
||||
if not dist.is_initialized():
|
||||
backend = backend or "nccl"
|
||||
assert backend == 'nccl', (
|
||||
"only use nccl backend for starting the NCCL communicator")
|
||||
dist.init_process_group(backend=backend,
|
||||
init_method=init_method,
|
||||
timeout=timeout,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
store=store,
|
||||
group_name=group_name,
|
||||
pg_options=pg_options)
|
||||
self.rank = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
if local_rank == -1:
|
||||
local_rank = self.rank
|
||||
self.local_rank = local_rank
|
||||
# don't use these args, as they can be -1
|
||||
# use `self.rank`, `self.local_rank` and `self.world_size` instead
|
||||
del world_size, rank, local_rank
|
||||
torch.cuda.set_device(self.local_rank)
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the NCCLCommunicator to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device.
|
||||
"""
|
||||
assert dist.is_initialized()
|
||||
group = get_cpu_world_group() if group is None else group
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"NCCLCommunicator should be attached to a non-NCCL group.")
|
||||
self.group = group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
if self.rank == 0:
|
||||
self.unique_id = ncclGetUniqueId()
|
||||
else:
|
||||
self.unique_id = NcclUniqueId()
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
|
||||
self.local_rank)
|
||||
dist.broadcast(tensor, src=0)
|
||||
byte_list = tensor.cpu().tolist()
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
dist.broadcast(tensor, src=0, group=group)
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
self.comm = ctypes.c_void_p()
|
||||
result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
|
||||
self.unique_id, self.rank)
|
||||
assert result == 0
|
||||
self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
|
||||
if device is None:
|
||||
local_rank = get_local_rank()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
elif isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
# nccl communicator and stream will use this device
|
||||
current_device = torch.cuda.current_device()
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
NCCL_CHECK(
|
||||
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
|
||||
self.unique_id, self.rank))
|
||||
self.stream = torch.cuda.Stream()
|
||||
finally:
|
||||
torch.cuda.set_device(current_device)
|
||||
|
||||
def all_reduce(self,
|
||||
tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None):
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
|
||||
NCCL_CHECK(
|
||||
_c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
|
||||
ctypes.c_void_p(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
ctypes.c_void_p(stream.cuda_stream))
|
||||
assert result == 0
|
||||
ctypes.c_void_p(stream.cuda_stream)))
|
||||
|
||||
def __del__(self):
|
||||
# `dist` module might have been already destroyed
|
||||
|
||||
@ -2,7 +2,7 @@ import contextlib
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
|
||||
pass
|
||||
|
||||
|
||||
def init_process_group(world_size: int,
|
||||
rank: int,
|
||||
init_method: str,
|
||||
local_rank: int = -1) -> None:
|
||||
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
|
||||
assert not is_initialized()
|
||||
global comm
|
||||
logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
|
||||
comm = NCCLCommunicator(init_method=init_method,
|
||||
world_size=world_size,
|
||||
local_rank=local_rank,
|
||||
rank=rank)
|
||||
comm = NCCLCommunicator(group=group)
|
||||
|
||||
|
||||
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
"""Tensor and pipeline parallel groups."""
|
||||
import contextlib
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -73,6 +74,11 @@ def init_distributed_environment(
|
||||
ranks = list(range(torch.distributed.get_world_size()))
|
||||
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
|
||||
backend="gloo")
|
||||
# set the local rank
|
||||
# local_rank is not available in torch ProcessGroup,
|
||||
# see https://github.com/pytorch/pytorch/issues/122816
|
||||
if local_rank == -1 and distributed_init_method == "env://":
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
global _LOCAL_RANK
|
||||
_LOCAL_RANK = local_rank
|
||||
|
||||
|
||||
@ -298,12 +298,9 @@ def init_worker_distributed_environment(
|
||||
elif parallel_config.world_size > 1:
|
||||
# NOTE(woosuk): We don't initialize pynccl process group when world size
|
||||
# is 1.
|
||||
pynccl_utils.init_process_group(
|
||||
world_size=parallel_config.world_size,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
init_method=distributed_init_method,
|
||||
)
|
||||
# NOTE(kaichao): By default, pynccl will use information inside
|
||||
# `parallel_state` for initialization.
|
||||
pynccl_utils.init_process_group()
|
||||
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user