[core][distributed] add stateless process group (#10216)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2024-11-11 09:02:14 -08:00 committed by GitHub
parent 36fc439de0
commit e6de9784d2
3 changed files with 206 additions and 101 deletions

tests/distributed/test_utils.py

@@ -1,10 +1,10 @@
 import pytest
 import ray
 import torch
-import torch.distributed as dist

 import vllm.envs as envs
-from vllm.distributed.utils import stateless_init_process_group
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.utils import (cuda_device_count_stateless,
                         update_environment_variables)
@@ -41,42 +41,45 @@ def test_cuda_device_count_stateless():
 def cpu_worker(rank, WORLD_SIZE):
-    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29500",
-                                       rank=rank,
-                                       world_size=WORLD_SIZE,
-                                       backend="gloo")
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29500",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
     if rank <= 2:
-        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29501",
-                                           rank=rank,
-                                           world_size=3,
-                                           backend="gloo")
+        pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29501",
+                                           rank=rank,
+                                           world_size=3)
     data = torch.tensor([rank])
-    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    data = pg1.broadcast_obj(data, src=2)
+    assert data.item() == 2
     if rank <= 2:
-        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
-    item = data[0].item()
-    print(f"rank: {rank}, item: {item}")
-    if rank == 3:
-        assert item == 6
-    else:
-        assert item == 18
+        data = torch.tensor([rank + 1])
+        data = pg2.broadcast_obj(data, src=2)
+        assert data.item() == 3
+        pg2.barrier()
+    pg1.barrier()


 def gpu_worker(rank, WORLD_SIZE):
-    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502",
-                                       rank=rank,
-                                       world_size=WORLD_SIZE,
-                                       backend="nccl")
-    if rank <= 2:
-        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503",
-                                           rank=rank,
-                                           world_size=3,
-                                           backend="nccl")
     torch.cuda.set_device(rank)
-    data = torch.tensor([rank]).cuda()
-    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29502",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    pynccl1 = PyNcclCommunicator(pg1, device=rank)
+    pynccl1.disabled = False
     if rank <= 2:
-        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
+        pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29503",
+                                           rank=rank,
+                                           world_size=3)
+        pynccl2 = PyNcclCommunicator(pg2, device=rank)
+        pynccl2.disabled = False
+    data = torch.tensor([rank]).cuda()
+    pynccl1.all_reduce(data)
+    pg1.barrier()
+    torch.cuda.synchronize()
+    if rank <= 2:
+        pynccl2.all_reduce(data)
+        pg2.barrier()
+        torch.cuda.synchronize()
     item = data[0].item()
     print(f"rank: {rank}, item: {item}")
     if rank == 3:
@@ -85,9 +88,31 @@ def gpu_worker(rank, WORLD_SIZE):
         assert item == 18


+def broadcast_worker(rank, WORLD_SIZE):
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29504",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    if rank == 2:
+        pg1.broadcast_obj("secret", src=2)
+    else:
+        obj = pg1.broadcast_obj(None, src=2)
+        assert obj == "secret"
+    pg1.barrier()
+
+
+def allgather_worker(rank, WORLD_SIZE):
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29505",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    data = pg1.all_gather_obj(rank)
+    assert data == list(range(WORLD_SIZE))
+    pg1.barrier()
+
+
 @multi_gpu_test(num_gpus=4)
-@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
-def test_stateless_init_process_group(worker):
+@pytest.mark.parametrize(
+    "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
+def test_stateless_process_group(worker):
     WORLD_SIZE = 4
     from multiprocessing import get_context
     ctx = get_context("fork")
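
The test driver above is truncated at the fork context. For orientation, here is a minimal sketch of how such workers can be launched, assuming one forked process per rank; the helper name run_workers is illustrative and not part of the commit:

from multiprocessing import get_context

def run_workers(worker, world_size=4):
    # Fork one process per rank; each rank runs the worker end to end.
    ctx = get_context("fork")
    procs = [
        ctx.Process(target=worker, args=(rank, world_size))
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
        # a rank that failed an assert exits non-zero
        assert p.exitcode == 0

# e.g. run_workers(cpu_worker) exercises the CPU-only metadata path;
# run_workers(gpu_worker) additionally needs 4 visible CUDA devices.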

vllm/distributed/device_communicators/pynccl.py

@@ -9,6 +9,7 @@ from torch.distributed import ProcessGroup, ReduceOp
 from vllm.distributed.device_communicators.pynccl_wrapper import (
     NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
     ncclRedOpTypeEnum, ncclUniqueId)
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger

 logger = init_logger(__name__)
@@ -18,7 +19,7 @@ class PyNcclCommunicator:
     def __init__(
         self,
-        group: ProcessGroup,
+        group: Union[ProcessGroup, StatelessProcessGroup],
         device: Union[int, str, torch.device],
         library_path: Optional[str] = None,
     ):
@@ -33,13 +34,18 @@ class PyNcclCommunicator:
         It is the caller's responsibility to make sure each communicator
         is bind to a unique device.
         """
-        assert dist.is_initialized()
-        assert dist.get_backend(group) != dist.Backend.NCCL, (
-            "PyNcclCommunicator should be attached to a non-NCCL group.")
+        if not isinstance(group, StatelessProcessGroup):
+            assert dist.is_initialized()
+            assert dist.get_backend(group) != dist.Backend.NCCL, (
+                "PyNcclCommunicator should be attached to a non-NCCL group.")
+            # note: this rank is the rank in the group
+            self.rank = dist.get_rank(group)
+            self.world_size = dist.get_world_size(group)
+        else:
+            self.rank = group.rank
+            self.world_size = group.world_size
         self.group = group
-        # note: this rank is the rank in the group
-        self.rank = dist.get_rank(group)
-        self.world_size = dist.get_world_size(group)

         # if world_size == 1, no need to create communicator
         if self.world_size == 1:
@@ -68,13 +74,17 @@ class PyNcclCommunicator:
         else:
             # construct an empty unique id
             self.unique_id = ncclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal))
-        ranks = dist.get_process_group_ranks(group)
-        # arg `src` in `broadcast` is the global rank
-        dist.broadcast(tensor, src=ranks[0], group=group)
-        byte_list = tensor.tolist()
-        for i, byte in enumerate(byte_list):
-            self.unique_id.internal[i] = byte
+
+        if not isinstance(group, StatelessProcessGroup):
+            tensor = torch.ByteTensor(list(self.unique_id.internal))
+            ranks = dist.get_process_group_ranks(group)
+            # arg `src` in `broadcast` is the global rank
+            dist.broadcast(tensor, src=ranks[0], group=group)
+            byte_list = tensor.tolist()
+            for i, byte in enumerate(byte_list):
+                self.unique_id.internal[i] = byte
+        else:
+            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
         if isinstance(device, int):
             device = torch.device(f"cuda:{device}")
         elif isinstance(device, str):
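
The net effect of this file's changes: PyNcclCommunicator no longer needs a torch.distributed group. When handed a StatelessProcessGroup it takes rank and world_size from the dataclass and ships the ncclUniqueId through broadcast_obj instead of dist.broadcast. A minimal sketch of the resulting bootstrap pattern, mirroring gpu_worker in the test file (the address, port, and tensor contents are arbitrary choices for illustration):

import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

def bootstrap_nccl(rank: int, world_size: int):
    # Bind this process to its GPU before creating the communicator.
    torch.cuda.set_device(rank)
    # Control plane: TCPStore-backed group, used only for metadata exchange.
    pg = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29502",
                                      rank=rank,
                                      world_size=world_size)
    # Data plane: rank 0's ncclUniqueId reaches the other ranks via pg.broadcast_obj.
    comm = PyNcclCommunicator(pg, device=rank)
    comm.disabled = False
    data = torch.ones(8, device=f"cuda:{rank}")
    comm.all_reduce(data)  # NCCL all-reduce, no torch.distributed involved
    pg.barrier()
    torch.cuda.synchronize()
    return comm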

vllm/distributed/utils.py

@@ -2,13 +2,13 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-from typing import Sequence, Tuple
+import dataclasses
+import pickle
+import time
+from collections import deque
+from typing import Any, Deque, Dict, Optional, Sequence, Tuple

 import torch
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import (Backend, PrefixStore,
-                                                _get_default_timeout,
-                                                is_nccl_available)
 from torch.distributed.rendezvous import rendezvous

 import vllm.envs as envs
@@ -91,69 +91,139 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
     return (start_layer, end_layer)


-def stateless_init_process_group(init_method: str, rank: int, world_size: int,
-                                 backend: str) -> ProcessGroup:
-    """A replacement for `torch.distributed.init_process_group` that does not
-    pollute the global state.
-
-    If we have process A and process B called `torch.distributed.init_process_group`
-    to form a group, and then we want to form another group with process A, B, C,
-    D, it is not possible in PyTorch, because process A and process B have already
-    formed a group, and process C and process D cannot join that group. This
-    function is a workaround for this issue.
-
-    `torch.distributed.init_process_group` is a global call, while this function
-    is a stateless call. It will return a `ProcessGroup` object that can be used
-    for collective communication. With this function, process A and process B
-    can call `stateless_init_process_group` to form a group, and then process A, B,
-    C, and D can call `stateless_init_process_group` to form another group.
-    """  # noqa
-    backend = Backend(backend)  # it is basically string
-    timeout = _get_default_timeout(backend)
-
-    store, rank, world_size = next(
-        rendezvous(init_method, rank, world_size, timeout=timeout))
-    store.set_timeout(timeout)
-
-    group_rank = rank
-    group_size = world_size
-
-    # Use a PrefixStore to avoid accidental overrides of keys used by
-    # different systems (e.g. RPC) in case the store is multi-tenant.
-    prefix_store = PrefixStore(init_method, store)
-
-    pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
-
-    pg: ProcessGroup = ProcessGroup(
-        prefix_store,
-        group_rank,
-        group_size,
-        pg_options,
-    )
-
-    if backend == "gloo":
-        from torch.distributed.distributed_c10d import ProcessGroupGloo
-        backend_class = ProcessGroupGloo(prefix_store,
-                                         group_rank,
-                                         group_size,
-                                         timeout=timeout)
-        backend_type = ProcessGroup.BackendType.GLOO
-        device = torch.device("cpu")
-    elif backend == "nccl":
-        assert is_nccl_available()
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-
-    backend_class._set_sequence_number_for_group()
-
-    pg._register_backend(device, backend_type, backend_class)
-
-    return pg
+@dataclasses.dataclass
+class StatelessProcessGroup:
+    """A dataclass to hold a metadata store, and the rank, world_size of the
+    group. Only use it to communicate metadata between processes.
+    For data-plane communication, create NCCL-related objects.
+    """
+    prefix: str
+    rank: int
+    world_size: int
+    store: torch._C._distributed_c10d.Store
+    data_expiration_seconds: int = 3600  # 1 hour
+
+    # dst rank -> counter
+    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    # src rank -> counter
+    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    broadcast_send_counter: int = 0
+    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
+        default_factory=dict)
+
+    # A deque to store the data entries, with key and timestamp.
+    entries: Deque[Tuple[str,
+                         float]] = dataclasses.field(default_factory=deque)
+
+    def __post_init__(self):
+        assert self.rank < self.world_size
+        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
+        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
+        self.broadcast_recv_src_counter = {
+            i: 0
+            for i in range(self.world_size)
+        }
+
+    def send_obj(self, obj: Any, dst: int):
+        """Send an object to a destination rank."""
+        self.expire_data()
+        key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
+        self.store.set(key, pickle.dumps(obj))
+        self.send_dst_counter[dst] += 1
+        self.entries.append((key, time.time()))
+
+    def expire_data(self):
+        """Expire data that is older than `data_expiration_seconds` seconds."""
+        while self.entries:
+            # check the oldest entry
+            key, timestamp = self.entries[0]
+            if time.time() - timestamp > self.data_expiration_seconds:
+                self.store.delete_key(key)
+                self.entries.popleft()
+            else:
+                break
+
+    def recv_obj(self, src: int) -> Any:
+        """Receive an object from a source rank."""
+        obj = pickle.loads(
+            self.store.get(
+                f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
+            ))
+        self.recv_src_counter[src] += 1
+        return obj
+
+    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
+        """Broadcast an object from a source rank to all other ranks.
+        It does not clean up after all ranks have received the object.
+        Use it for limited times, e.g., for initialization.
+        """
+        if self.rank == src:
+            self.expire_data()
+            key = (f"{self.prefix}/broadcast_from/{src}/"
+                   f"{self.broadcast_send_counter}")
+            self.store.set(key, pickle.dumps(obj))
+            self.broadcast_send_counter += 1
+            self.entries.append((key, time.time()))
+            return obj
+        else:
+            key = (f"{self.prefix}/broadcast_from/{src}/"
+                   f"{self.broadcast_recv_src_counter[src]}")
+            recv_obj = pickle.loads(self.store.get(key))
+            self.broadcast_recv_src_counter[src] += 1
+            return recv_obj
+
+    def all_gather_obj(self, obj: Any) -> list[Any]:
+        """All gather an object from all ranks."""
+        gathered_objs = []
+        for i in range(self.world_size):
+            if i == self.rank:
+                gathered_objs.append(obj)
+                self.broadcast_obj(obj, src=self.rank)
+            else:
+                recv_obj = self.broadcast_obj(None, src=i)
+                gathered_objs.append(recv_obj)
+        return gathered_objs
+
+    def barrier(self):
+        """A barrier to synchronize all ranks."""
+        for i in range(self.world_size):
+            if i == self.rank:
+                self.broadcast_obj(None, src=self.rank)
+            else:
+                self.broadcast_obj(None, src=i)
+
+    @staticmethod
+    def create(
+        init_method: str,
+        rank: int,
+        world_size: int,
+        data_expiration_seconds: int = 3600,
+    ) -> "StatelessProcessGroup":
+        """A replacement for `torch.distributed.init_process_group` that does not
+        pollute the global state.
+
+        If we have process A and process B called `torch.distributed.init_process_group`
+        to form a group, and then we want to form another group with process A, B, C,
+        D, it is not possible in PyTorch, because process A and process B have already
+        formed a group, and process C and process D cannot join that group. This
+        function is a workaround for this issue.
+
+        `torch.distributed.init_process_group` is a global call, while this function
+        is a stateless call. It will return a `StatelessProcessGroup` object that can be
+        used for exchanging metadata. With this function, process A and process B
+        can call `StatelessProcessGroup.create` to form a group, and then process A, B,
+        C, and D can call `StatelessProcessGroup.create` to form another group.
+        """  # noqa
+        from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
+        timeout = _DEFAULT_PG_TIMEOUT
+
+        store, rank, world_size = next(
+            rendezvous(init_method, rank, world_size, timeout=timeout))
+        store.set_timeout(timeout)
+
+        return StatelessProcessGroup(
+            prefix=init_method,
+            rank=rank,
+            world_size=world_size,
+            store=store,
+            data_expiration_seconds=data_expiration_seconds)
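
For completeness, a small CPU-only sketch of the control-plane API added here: create a store-backed group, exchange picklable objects with send_obj/recv_obj, and use the broadcast_obj-based helpers for collective metadata. The port, payload, and two-process setup are arbitrary example choices:

from multiprocessing import get_context

from vllm.distributed.utils import StatelessProcessGroup

def metadata_worker(rank: int, world_size: int):
    pg = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29600",
                                      rank=rank,
                                      world_size=world_size)
    # Point-to-point: rank 0 posts a dict into the store, rank 1 picks it up.
    if rank == 0:
        pg.send_obj({"hello": "world"}, dst=1)
    elif rank == 1:
        assert pg.recv_obj(src=0) == {"hello": "world"}
    # Collective helpers layered on broadcast_obj.
    assert pg.all_gather_obj(rank) == list(range(world_size))
    pg.barrier()

if __name__ == "__main__":
    world_size = 2
    ctx = get_context("fork")
    procs = [
        ctx.Process(target=metadata_worker, args=(r, world_size))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()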