[Core][Distributed] use cpu/gloo to initialize pynccl (#4248)
commit 91f50a6fe2
parent 79a268c4ab
@@ -5,6 +5,7 @@ import torch
 
 from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                            ncclGetUniqueId)
+from vllm.distributed.parallel_state import init_distributed_environment
 from vllm.utils import update_environment_variables
 
 
@@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
     for p in processes:
         p.join()
 
+    for p in processes:
+        assert p.exitcode == 0
+
 
-def update_env(fn):
+def worker_fn_wrapper(fn):
     # `multiprocessing.Process` cannot accept environment variables directly
     # so we need to pass the environment variables as arguments
     # and update the environment variables in the function
-    def wrapper(env):
+    def wrapped_fn(env):
         update_environment_variables(env)
+        init_distributed_environment()
         fn()
 
-    return wrapper
+    return wrapped_fn
 
 
-@update_env
+@worker_fn_wrapper
 def worker_fn():
     comm = NCCLCommunicator()
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
@@ -53,7 +58,7 @@ def test_pynccl():
     distributed_run(worker_fn, 2)
 
 
-@update_env
+@worker_fn_wrapper
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
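The hunks above are the test-side change: each spawned worker now receives its environment variables, calls init_distributed_environment() (which sets up torch.distributed plus the gloo CPU world group), and only then constructs an NCCLCommunicator. The hunks that follow modify vllm.distributed.device_communicators.pynccl itself. As a minimal standalone sketch of the new launch pattern, assuming a 2-GPU host; the port, tensor size, and helper names outside the shown vLLM imports are illustrative, not taken from the commit:

    import multiprocessing

    import torch

    from vllm.distributed.device_communicators.pynccl import NCCLCommunicator
    from vllm.distributed.parallel_state import init_distributed_environment
    from vllm.utils import update_environment_variables


    def _worker(env):
        update_environment_variables(env)  # RANK / LOCAL_RANK / WORLD_SIZE / MASTER_*
        init_distributed_environment()     # torch.distributed + the gloo CPU world group
        comm = NCCLCommunicator()          # NCCL unique id is exchanged over that CPU group
        x = torch.ones(4, dtype=torch.float32).cuda(comm.rank)
        # run the collective on the current stream so .item() below sees the result
        comm.all_reduce(x, stream=torch.cuda.current_stream())
        assert x[0].item() == comm.world_size


    if __name__ == "__main__":
        world_size = 2
        ctx = multiprocessing.get_context("spawn")
        procs = []
        for rank in range(world_size):
            env = {
                "RANK": str(rank),
                "LOCAL_RANK": str(rank),
                "WORLD_SIZE": str(world_size),
                "MASTER_ADDR": "localhost",
                "MASTER_PORT": "29501",  # any free port
            }
            procs.append(ctx.Process(target=_worker, args=(env, )))
        for p in procs:
            p.start()
        for p in procs:
            p.join()
            assert p.exitcode == 0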
@@ -20,14 +20,15 @@
 # variable in the code.
 
 import ctypes
-import datetime
 import platform
+from typing import Optional, Union
 
 # ===================== import region =====================
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
 
+from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
 from vllm.logger import init_logger
 from vllm.utils import find_nccl_library, nccl_integrity_check
 
@@ -59,6 +60,18 @@ except Exception as e:
 
 ncclResult_t = ctypes.c_int
 
+_c_ncclGetErrorString = nccl.ncclGetErrorString
+_c_ncclGetErrorString.restype = ctypes.c_char_p
+_c_ncclGetErrorString.argtypes = [ncclResult_t]
+
+
+def NCCL_CHECK(result: ncclResult_t) -> None:
+    if result != 0:
+        error_str = _c_ncclGetErrorString(result)
+        error_str = error_str.decode("utf-8")
+        raise RuntimeError(f"NCCL error: {error_str}")
+
+
 # equivalent to c declaration:
 # ncclResult_t ncclGetVersion(int *version);
 _c_ncclGetVersion = nccl.ncclGetVersion
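The new NCCL_CHECK helper replaces the scattered `assert result == 0` checks: any non-zero ncclResult_t is turned into a readable exception via ncclGetErrorString, whose restype/argtypes are declared so ctypes hands back a real C string rather than an integer. Below is a self-contained sketch of the same idiom, with the C call faked in Python so the snippet runs without libnccl; the fake function and its message are illustrative only:

    import ctypes

    ncclResult_t = ctypes.c_int  # mirrors the alias used in the diff


    def _fake_nccl_get_error_string(result: int) -> bytes:
        # Stand-in for nccl.ncclGetErrorString (restype=c_char_p, argtypes=[ncclResult_t]).
        return b"no error" if result == 0 else b"unhandled cuda error"


    def NCCL_CHECK(result: int) -> None:
        if result != 0:
            error_str = _fake_nccl_get_error_string(result).decode("utf-8")
            raise RuntimeError(f"NCCL error: {error_str}")


    NCCL_CHECK(0)       # success: returns silently
    try:
        NCCL_CHECK(1)   # failure: raises with a message instead of a bare assert
    except RuntimeError as err:
        print(err)      # -> NCCL error: unhandled cuda error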
@@ -68,8 +81,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
 
 def ncclGetVersion() -> str:
     version = ctypes.c_int()
-    result = _c_ncclGetVersion(ctypes.byref(version))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
     # something like 21903 --> "2.19.3"
     version_str = str(version.value)
     major = version_str[0].lstrip("0")
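For reference, the version integer is decoded digit-wise, as the in-code comment notes (21903 becomes "2.19.3"). The minor/patch slicing sits outside this hunk, so the sketch below simply assumes the layout is one major digit, two minor digits, and the remainder as patch:

    def decode_nccl_version(value: int) -> str:
        # Assumed layout: 1 digit major, 2 digits minor, remaining digits patch.
        s = str(value)
        major = s[0].lstrip("0")   # matches the line shown in the hunk
        minor = s[1:3].lstrip("0")
        patch = s[3:].lstrip("0")
        return f"{major}.{minor}.{patch}"


    assert decode_nccl_version(21903) == "2.19.3"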
@@ -91,8 +103,7 @@ _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
 
 def ncclGetUniqueId() -> NcclUniqueId:
     unique_id = NcclUniqueId()
-    result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
     return unique_id
 
 
@@ -199,66 +210,75 @@ class NCCLCommunicator:
 
     def __init__(
         self,
-        backend=None,
-        init_method=None,
-        timeout=datetime.timedelta(seconds=10),
-        world_size: int = -1,
-        rank: int = -1,
-        store=None,
-        group_name: str = "",
-        pg_options=None,
-        local_rank: int = -1,
+        group: Optional[ProcessGroup] = None,
+        device: Optional[Union[int, str, torch.device]] = None,
     ):
-        if not dist.is_initialized():
-            backend = backend or "nccl"
-            assert backend == 'nccl', (
-                "only use nccl backend for starting the NCCL communicator")
-            dist.init_process_group(backend=backend,
-                                    init_method=init_method,
-                                    timeout=timeout,
-                                    world_size=world_size,
-                                    rank=rank,
-                                    store=store,
-                                    group_name=group_name,
-                                    pg_options=pg_options)
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
-        if local_rank == -1:
-            local_rank = self.rank
-        self.local_rank = local_rank
-        # don't use these args, as they can be -1
-        # use `self.rank`, `self.local_rank` and `self.world_size` instead
-        del world_size, rank, local_rank
-        torch.cuda.set_device(self.local_rank)
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the NCCLCommunicator to. If None,
+                it will be bind to f"cuda:{local_rank}".
+        It is the caller's responsibility to make sure each communicator
+        is bind to a unique device.
+        """
+        assert dist.is_initialized()
+        group = get_cpu_world_group() if group is None else group
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "NCCLCommunicator should be attached to a non-NCCL group.")
+        self.group = group
+        self.rank = dist.get_rank(group)
+        self.world_size = dist.get_world_size(group)
         if self.rank == 0:
             self.unique_id = ncclGetUniqueId()
         else:
             self.unique_id = NcclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
-            self.local_rank)
-        dist.broadcast(tensor, src=0)
-        byte_list = tensor.cpu().tolist()
+        tensor = torch.ByteTensor(list(self.unique_id.internal))
+        dist.broadcast(tensor, src=0, group=group)
+        byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
         self.comm = ctypes.c_void_p()
-        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
-                                     self.unique_id, self.rank)
-        assert result == 0
-        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+        if device is None:
+            local_rank = get_local_rank()
+            device = torch.device(f"cuda:{local_rank}")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+        # nccl communicator and stream will use this device
+        current_device = torch.cuda.current_device()
+        try:
+            torch.cuda.set_device(device)
+            NCCL_CHECK(
+                _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                    self.unique_id, self.rank))
+            self.stream = torch.cuda.Stream()
+        finally:
+            torch.cuda.set_device(current_device)
 
     def all_reduce(self,
                    tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None):
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
-                                  ctypes.c_void_p(tensor.data_ptr()),
-                                  tensor.numel(),
-                                  ncclDataTypeEnum.from_torch(tensor.dtype),
-                                  ncclRedOpTypeEnum.from_torch(op), self.comm,
-                                  ctypes.c_void_p(stream.cuda_stream))
-        assert result == 0
+        NCCL_CHECK(
+            _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                             ctypes.c_void_p(tensor.data_ptr()),
                             tensor.numel(),
                             ncclDataTypeEnum.from_torch(tensor.dtype),
                             ncclRedOpTypeEnum.from_torch(op), self.comm,
                             ctypes.c_void_p(stream.cuda_stream)))
 
     def __del__(self):
         # `dist` module might have been already destroyed
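This is the heart of the change: instead of letting NCCLCommunicator create a torch NCCL process group and push the unique id through a CUDA tensor, rank 0's ncclGetUniqueId() result is now broadcast as a plain CPU ByteTensor over the gloo-backed group, and ncclCommInitRank is only called afterwards on an explicitly chosen device. The try/finally around torch.cuda.set_device pins the communicator and its stream to that one GPU, which is also why all_reduce now asserts that incoming tensors live on the same device. The hunks after this one update the pynccl_utils wrapper accordingly. A standalone sketch of the handshake using bare torch.distributed with gloo; the 128-byte payload is faked here so the sketch runs without libnccl, and it is meant to be launched with something like torchrun so the env-based init resolves:

    import os

    import torch
    import torch.distributed as dist

    NCCL_UNIQUE_ID_BYTES = 128  # size of ncclUniqueId's internal buffer


    def exchange_unique_id(group=None) -> bytes:
        """Broadcast an opaque id from rank 0 over a CPU/gloo group."""
        if dist.get_rank(group) == 0:
            payload = os.urandom(NCCL_UNIQUE_ID_BYTES)  # real code: ncclGetUniqueId()
        else:
            payload = bytes(NCCL_UNIQUE_ID_BYTES)       # placeholder, overwritten below
        tensor = torch.ByteTensor(list(payload))        # stays on CPU, so gloo can ship it
        dist.broadcast(tensor, src=0, group=group)
        return bytes(tensor.tolist())                   # every rank now holds rank 0's id


    if __name__ == "__main__":
        # e.g. `torchrun --nproc-per-node=2 this_sketch.py`
        dist.init_process_group(backend="gloo")
        uid = exchange_unique_id()
        print(f"rank {dist.get_rank()} sees id prefix {uid[:4].hex()}")  # same on all ranks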
@@ -2,7 +2,7 @@ import contextlib
 from typing import Optional
 
 import torch
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
 
 from vllm.logger import init_logger
 
@@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
         pass
 
 
-def init_process_group(world_size: int,
-                       rank: int,
-                       init_method: str,
-                       local_rank: int = -1) -> None:
+def init_process_group(group: Optional[ProcessGroup] = None) -> None:
     assert not is_initialized()
     global comm
     logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
-    comm = NCCLCommunicator(init_method=init_method,
-                            world_size=world_size,
-                            local_rank=local_rank,
-                            rank=rank)
+    comm = NCCLCommunicator(group=group)
 
 
 def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
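With the trimmed-down init_process_group(group=None), callers no longer re-plumb world size, rank, and init method; they only have to make sure vllm.distributed.parallel_state is initialized first, and pynccl picks everything up from the CPU/gloo world group. A hedged usage sketch follows; the pynccl_utils import path is inferred from the neighbouring modules in this diff, and the snippet assumes a launcher such as torchrun so the env:// defaults resolve:

    from vllm.distributed import parallel_state
    from vllm.distributed.device_communicators import pynccl_utils

    # 1. Global torch.distributed state plus the gloo CPU world group.
    parallel_state.init_distributed_environment()

    # 2. pynccl now reads rank/world size from that group instead of taking arguments.
    pynccl_utils.init_process_group()

    # Equivalent explicit form: pass any non-NCCL group for pynccl to piggyback on.
    # pynccl_utils.init_process_group(group=parallel_state.get_cpu_world_group())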
@@ -4,6 +4,7 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
 import contextlib
+import os
 from typing import Optional
 
 import torch
@@ -73,6 +74,11 @@ def init_distributed_environment(
         ranks = list(range(torch.distributed.get_world_size()))
         _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
                                                        backend="gloo")
+    # set the local rank
+    # local_rank is not available in torch ProcessGroup,
+    # see https://github.com/pytorch/pytorch/issues/122816
+    if local_rank == -1 and distributed_init_method == "env://":
+        local_rank = int(os.environ['LOCAL_RANK'])
     global _LOCAL_RANK
     _LOCAL_RANK = local_rank
 
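The added parallel_state lines cover the one piece of information a torch ProcessGroup cannot provide, the local rank: when the caller passes -1 and the init method is env://, it falls back to the LOCAL_RANK variable that launchers like torchrun export (see the linked pytorch issue). A tiny sketch of that fallback, under the assumption that it mirrors the added lines; the function name is illustrative:

    import os


    def resolve_local_rank(local_rank: int, distributed_init_method: str) -> int:
        # Only trust the environment in the env:// case, where a launcher
        # (e.g. torchrun --nproc-per-node=N) is expected to export LOCAL_RANK.
        if local_rank == -1 and distributed_init_method == "env://":
            local_rank = int(os.environ["LOCAL_RANK"])
        return local_rank

With a manual tcp:// init_method the caller still has to pass local_rank explicitly.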
@@ -298,12 +298,9 @@ def init_worker_distributed_environment(
     elif parallel_config.world_size > 1:
         # NOTE(woosuk): We don't initialize pynccl process group when world size
         # is 1.
-        pynccl_utils.init_process_group(
-            world_size=parallel_config.world_size,
-            local_rank=local_rank,
-            rank=rank,
-            init_method=distributed_init_method,
-        )
+        # NOTE(kaichao): By default, pynccl will use information inside
+        # `parallel_state` for initialization.
+        pynccl_utils.init_process_group()
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)