[core][distributed] add zmq fallback for broadcasting large objects (#6183)

commit da78caecfa (parent 2416b26e11)
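Overview: the shared-memory ring buffer used to broadcast Python objects inside a process group has a fixed chunk size, so any object whose pickle exceeded one chunk used to raise a RuntimeError. This commit adds a ZeroMQ fallback: oversized objects are published on a zmq PUB socket and fetched by readers from matching SUB sockets, while small objects keep the fast shared-memory path; readers on other nodes, which cannot map the shared memory at all, always receive over zmq. In sketch form, the writer-side dispatch (condensed from MessageQueue.enqueue below; the helper names here are illustrative, not the real API):

    data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    if n_local_reader > 0:                  # readers on the writer's node
        if len(data) >= max_chunk_bytes:    # too big for one ring-buffer chunk
            mark_chunk_as_overflow()        # flag byte tells readers to fall back
            local_pub_socket.send(data)     # zmq PUB -> local SUB sockets
        else:
            write_inline_into_chunk(data)   # fast path, shared memory only
    if n_remote_reader > 0:                 # readers on other nodes
        remote_pub_socket.send(data)        # always zmq across nodes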
requirements-common.txt
@@ -21,3 +21,4 @@ lm-format-enforcer == 0.10.1
 outlines >= 0.0.43 # Requires torch >= 2.1.0
 typing_extensions
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
+pyzmq
tests/distributed/test_same_node.py
@@ -2,10 +2,11 @@ import os
 
 import torch
 
-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as
 
 torch.distributed.init_process_group(backend="gloo")
-test_result = is_in_the_same_node(torch.distributed.group.WORLD)
+test_result = all(
+    in_the_same_node_as(torch.distributed.group.WORLD, source_rank=0))
 
 expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
 assert test_result == expected, f"Expected {expected}, got {test_result}"
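The call-site change preserves the old boolean behavior: in_the_same_node_as now returns a per-rank list, and wrapping it in all() reproduces what is_in_the_same_node returned. For example (values illustrative):

    status = in_the_same_node_as(torch.distributed.group.WORLD, source_rank=0)
    # e.g. [True, True, True, True] when all four ranks share a host
    test_result = all(status)  # equivalent to the old is_in_the_same_node(...)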
tests/distributed/test_shm_broadcast.py
@@ -6,8 +6,7 @@ from typing import List
 import numpy as np
 import torch.distributed as dist
 
-from vllm.distributed.device_communicators.shm_broadcast import (
-    ShmRingBuffer, ShmRingBufferIO)
+from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
 from vllm.utils import update_environment_variables
 
 
@@ -56,8 +55,8 @@ def worker_fn_wrapper(fn):
 @worker_fn_wrapper
 def worker_fn():
     writer_rank = 2
-    broadcaster = ShmRingBufferIO.create_from_process_group(
-        dist.group.WORLD, 1024 * 1024, 2, writer_rank)
+    broadcaster = MessageQueue.create_from_process_group(
+        dist.group.WORLD, 40 * 1024, 2, writer_rank)
     if dist.get_rank() == writer_rank:
         seed = random.randint(0, 1000)
         dist.broadcast_object_list([seed], writer_rank)
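A plausible reading of the 1024 * 1024 -> 40 * 1024 change: with only a 40 KiB chunk, test payloads larger than that are forced through the new zmq overflow path, so the test covers both the inline shared-memory branch and the socket branch. For instance:

    import pickle
    import numpy as np
    big = np.zeros(16384, dtype=np.float64)      # 128 KiB of raw data
    assert len(pickle.dumps(big)) >= 40 * 1024   # would overflow a 40 KiB chunk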
@@ -87,13 +86,3 @@ def worker_fn():
 
 def test_shm_broadcast():
     distributed_run(worker_fn, 4)
-
-
-def test_singe_process():
-    buffer = ShmRingBuffer(1, 1024, 4)
-    reader = ShmRingBufferIO(buffer, reader_rank=0)
-    writer = ShmRingBufferIO(buffer, reader_rank=-1)
-    writer.enqueue([0])
-    writer.enqueue([1])
-    assert reader.dequeue() == [0]
-    assert reader.dequeue() == [1]
vllm/distributed/device_communicators/custom_all_reduce.py
@@ -9,7 +9,7 @@ import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless, is_full_nvlink
@@ -64,7 +64,7 @@ class CustomAllreduce:
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")
 
-        if not is_in_the_same_node(group):
+        if not all(in_the_same_node_as(group, source_rank=0)):
             # No need to initialize custom allreduce for multi-node case.
             logger.warning(
                 "Custom allreduce is disabled because this process group"
vllm/distributed/device_communicators/shm_broadcast.py
@@ -1,16 +1,19 @@
 import pickle
 import time
 from contextlib import contextmanager
+from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import Optional
+from typing import List, Optional
 from unittest.mock import patch
 
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context  # type: ignore
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import get_ip, get_open_port
 
 VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
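The new zmq import brings in the two socket patterns the rest of the file uses: PUB/SUB for one-to-many data fan-out and REQ/REP for the lock-step READY handshake. A minimal, self-contained pyzmq sketch of the fan-out half (the port number is an arbitrary example):

    import time
    from zmq import Context, PUB, SUB, SUBSCRIBE

    ctx = Context()
    pub = ctx.socket(PUB)
    pub.bind("tcp://*:5555")              # writer: one PUB, many SUBs

    sub = ctx.socket(SUB)
    sub.setsockopt_string(SUBSCRIBE, "")  # subscribe to every message
    sub.connect("tcp://127.0.0.1:5555")

    time.sleep(0.5)  # PUB silently drops messages sent before the subscription
                     # is established (the "slow joiner" problem); MessageQueue
                     # avoids sleeping by doing a REQ/REP READY handshake instead
    pub.send(b"hello")
    assert sub.recv() == b"hello"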
@@ -135,18 +138,183 @@ class ShmRingBuffer:
             yield buf
 
 
-class ShmRingBufferIO:
-
-    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
-        self.buffer = buffer
-        self.reader_rank = reader_rank
-        self._is_writer = self.reader_rank == -1
-        self._is_reader = not self._is_writer
-        if self._is_reader:
-            assert 0 <= self.reader_rank < buffer.n_reader, \
-                (f"Invalid reader rank {self.reader_rank} for buffer"
-                 f" created with {buffer.n_reader} readers")
-        self.current_idx = 0
+@dataclass
+class Handle:
+    connect_ip: str
+    local_reader_ranks: List[int] = field(default_factory=list)
+
+    buffer: Optional[ShmRingBuffer] = None
+    local_subscribe_port: Optional[int] = None
+    local_sync_port: Optional[int] = None
+    remote_subscribe_port: Optional[int] = None
+    remote_sync_port: Optional[int] = None
+
+
+class MessageQueue:
+
+    def __init__(
+        self,
+        n_reader,  # number of all readers
+        n_local_reader,  # number of local readers through shared memory
+        local_reader_ranks: Optional[List[int]] = None,
+        max_chunk_bytes: int = 1024 * 1024 * 10,
+        max_chunks: int = 10,
+        connect_ip: Optional[str] = None,
+    ):
+        if local_reader_ranks is None:
+            local_reader_ranks = list(range(n_local_reader))
+        else:
+            assert len(local_reader_ranks) == n_local_reader
+        self.n_local_reader = n_local_reader
+        n_remote_reader = n_reader - n_local_reader
+        self.n_remote_reader = n_remote_reader
+
+        if connect_ip is None:
+            connect_ip = get_ip()
+
+        context = Context()
+
+        if n_local_reader > 0:
+            # for local readers, we will:
+            # 1. create a shared memory ring buffer to communicate small data
+            # 2. create a publish-subscribe socket to communicate large data
+            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
+                                        max_chunks)
+
+            self.local_socket = context.socket(PUB)
+            local_subscribe_port = get_open_port()
+            self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
+
+            self.local_sync_socket = context.socket(REP)
+            local_sync_port = get_open_port()
+            self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
+            self.current_idx = 0
+
+        else:
+            self.buffer = None  # type: ignore
+            local_subscribe_port = None
+            local_sync_port = None
+            self.local_socket = None
+            self.local_sync_socket = None
+            self.current_idx = -1
+
+        if n_remote_reader > 0:
+            # for remote readers, we will:
+            # create a publish-subscribe socket to communicate large data
+            self.remote_socket = context.socket(PUB)
+            remote_subscribe_port = get_open_port()
+            self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
+
+            self.remote_sync_socket = context.socket(REP)
+            remote_sync_port = get_open_port()
+            self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
+        else:
+            remote_subscribe_port = None
+            remote_sync_port = None
+            self.remote_socket = None
+            self.remote_sync_socket = None
+
+        self._is_writer = True
+        self._is_local_reader = False
+        self.local_reader_rank = -1
+        # rank does not matter for remote readers
+        self._is_remote_reader = False
+
+        self.handle = Handle(
+            connect_ip=connect_ip,
+            local_reader_ranks=local_reader_ranks,
+            buffer=self.buffer,
+            local_subscribe_port=local_subscribe_port,
+            local_sync_port=local_sync_port,
+            remote_subscribe_port=remote_subscribe_port,
+            remote_sync_port=remote_sync_port,
+        )
+
+    def export_handle(self) -> Handle:
+        return self.handle
+
+    @staticmethod
+    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
+        self = MessageQueue.__new__(MessageQueue)
+        self.handle = handle
+        self._is_writer = False
+
+        context = Context()
+
+        if rank in handle.local_reader_ranks:
+            assert handle.buffer is not None
+            self.buffer = handle.buffer
+            self.current_idx = 0
+            self.local_reader_rank = handle.local_reader_ranks.index(rank)
+            self._is_local_reader = True
+            self._is_remote_reader = False
+
+            self.local_socket = context.socket(SUB)
+            self.local_socket.setsockopt_string(SUBSCRIBE, "")
+            self.local_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
+
+            self.local_sync_socket = context.socket(REQ)
+            self.local_sync_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
+
+            self.remote_socket = None
+            self.remote_sync_socket = None
+        else:
+            self.buffer = None  # type: ignore
+            self.current_idx = -1
+            self.local_reader_rank = -1
+            self._is_local_reader = False
+            self._is_remote_reader = True
+
+            self.local_socket = None
+            self.local_sync_socket = None
+
+            self.remote_socket = context.socket(SUB)
+            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
+            self.remote_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
+
+            self.remote_sync_socket = context.socket(REQ)
+            self.remote_sync_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
+
+        return self
+
+    def wait_until_ready(self):
+        """This is a collective operation. All processes (including the
+        readers and the writer) should call this function.
+        """
+        if self._is_writer:
+            # wait for all readers to connect
+
+            # local readers
+            for i in range(self.n_local_reader):
+                recv = self.local_sync_socket.recv()
+                assert recv == b"READY"
+                self.local_sync_socket.send(b"READY")
+            if self.n_local_reader > 0:
+                self.local_socket.send(b"READY")
+
+            # remote readers
+            for i in range(self.n_remote_reader):
+                recv = self.remote_sync_socket.recv()
+                assert recv == b"READY"
+                self.remote_sync_socket.send(b"READY")
+            if self.n_remote_reader > 0:
+                self.remote_socket.send(b"READY")
+        elif self._is_local_reader:
+            self.local_sync_socket.send(b"READY")
+            recv = self.local_sync_socket.recv()
+            assert recv == b"READY"
+            recv = self.local_socket.recv()
+            assert recv == b"READY"
+        elif self._is_remote_reader:
+            self.remote_sync_socket.send(b"READY")
+            recv = self.remote_sync_socket.recv()
+            assert recv == b"READY"
+            recv = self.remote_socket.recv()
+            assert recv == b"READY"
 
     @contextmanager
     def acquire_write(self):
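Usage sketch of the new bootstrap protocol (one writer plus one same-node reader; process management and handle transport are elided, and the vLLM code ships the pickled Handle via torch.distributed.broadcast_object_list, as the hunks below show):

    # writer process
    mq = MessageQueue(n_reader=1, n_local_reader=1)
    handle = mq.export_handle()  # a picklable Handle; send it to readers
    # (transport `handle` to the reader process here, e.g. over a pipe)
    mq.wait_until_ready()        # blocks until every reader has checked in
    mq.enqueue({"step": 0})

    # reader process, after receiving `handle`
    mq = MessageQueue.create_from_handle(handle, rank=0)
    mq.wait_until_ready()
    assert mq.dequeue() == {"step": 0}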
@@ -201,12 +369,12 @@ class ShmRingBufferIO:
 
     @contextmanager
     def acquire_read(self):
-        assert self._is_reader, "Only readers can acquire read"
+        assert self._is_local_reader, "Only readers can acquire read"
         start_time = time.monotonic()
         n_warning = 1
         while True:
             with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
-                read_flag = metadata_buffer[self.reader_rank + 1]
+                read_flag = metadata_buffer[self.local_reader_rank + 1]
                 written_flag = metadata_buffer[0]
                 if not written_flag or read_flag:
                     # this block is either
@@ -236,7 +404,7 @@ class ShmRingBufferIO:
 
                 # caller has read from the buffer
                 # set the read flag
-                metadata_buffer[self.reader_rank + 1] = 1
+                metadata_buffer[self.local_reader_rank + 1] = 1
                 self.current_idx = (self.current_idx +
                                     1) % self.buffer.max_chunks
                 break
@@ -244,21 +412,36 @@ class ShmRingBufferIO:
     def enqueue(self, obj):
         assert self._is_writer, "Only writers can enqueue"
         serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
-        if len(serialized_obj) > self.buffer.max_chunk_bytes:
-            raise RuntimeError(
-                f"{len(serialized_obj)=} larger than the allowed value "
-                f"{self.buffer.max_chunk_bytes},"
-                "Please increase the max_chunk_bytes parameter.")
-        with self.acquire_write() as buf:
-            buf[:len(serialized_obj)] = serialized_obj
+        if self.n_local_reader > 0:
+            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
+                with self.acquire_write() as buf:
+                    buf[0] = 1  # overflow
+                self.local_socket.send(serialized_obj)
+            else:
+                with self.acquire_write() as buf:
+                    buf[0] = 0  # not overflow
+                    buf[1:len(serialized_obj) + 1] = serialized_obj
+        if self.n_remote_reader > 0:
+            self.remote_socket.send(serialized_obj)
 
     def dequeue(self):
-        assert self._is_reader, "Only readers can dequeue"
-        with self.acquire_read() as buf:
-            # no need to know the size of serialized object
-            # pickle format itself contains the size information internally
-            # see https://docs.python.org/3/library/pickle.html
-            obj = pickle.loads(buf)
+        if self._is_local_reader:
+            overflow = False
+            with self.acquire_read() as buf:
+                overflow = buf[0] == 1
+                if not overflow:
+                    # no need to know the size of serialized object
+                    # pickle format contains the size information internally
+                    # see https://docs.python.org/3/library/pickle.html
+                    obj = pickle.loads(buf[1:])
+            if overflow:
+                recv = self.local_socket.recv()
+                obj = pickle.loads(recv)
+        elif self._is_remote_reader:
+            recv = self.remote_socket.recv()
+            obj = pickle.loads(recv)
+        else:
+            raise RuntimeError("Only readers can dequeue")
         return obj
 
     def broadcast_object(self, obj=None):
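The chunk layout implied by the new enqueue/dequeue pair: byte 0 of each chunk is an overflow flag, and the pickled payload starts at byte 1, so an object fits inline only while len(pickled) < max_chunk_bytes; anything at or above that is published on the zmq socket instead of raising, as the old code did. A quick check of which path a payload would take (threshold from the constructor default):

    import pickle
    max_chunk_bytes = 1024 * 1024 * 10  # MessageQueue default
    payload = list(range(100_000))
    n = len(pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL))
    print(n, "inline" if n < max_chunk_bytes else "zmq overflow")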
@@ -272,24 +455,36 @@ class ShmRingBufferIO:
     def create_from_process_group(pg: ProcessGroup,
                                   max_chunk_bytes,
                                   max_chunks,
-                                  writer_rank=0) -> "ShmRingBufferIO":
+                                  writer_rank=0) -> "MessageQueue":
         group_rank = dist.get_rank(pg)
         group_world_size = dist.get_world_size(pg)
-        ranks_inside_group = list(range(group_world_size))
         global_ranks = dist.get_process_group_ranks(pg)
+
+        from vllm.distributed.parallel_state import in_the_same_node_as
+        status = in_the_same_node_as(pg, source_rank=writer_rank)
+        same_node_ranks = [i for i, s in enumerate(status) if s]
         n_reader = group_world_size - 1
-        buffer: ShmRingBuffer
+        n_local_reader = len(same_node_ranks) - 1
+        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
+        buffer_io: MessageQueue
         if group_rank == writer_rank:
-            buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
-            dist.broadcast_object_list([buffer],
+            buffer_io = MessageQueue(
+                n_reader=n_reader,
+                n_local_reader=n_local_reader,
+                local_reader_ranks=local_reader_ranks,
+                max_chunk_bytes=max_chunk_bytes,
+                max_chunks=max_chunks,
+            )
+            handle = buffer_io.export_handle()
+            dist.broadcast_object_list([handle],
                                        src=global_ranks[writer_rank],
                                        group=pg)
-            return ShmRingBufferIO(buffer, -1)
         else:
             recv = [None]
             dist.broadcast_object_list(recv,
                                        src=global_ranks[writer_rank],
                                        group=pg)
-            buffer = recv[0]  # type: ignore
-            rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
-            return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
+            handle = recv[0]  # type: ignore
+            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
+        buffer_io.wait_until_ready()
+        return buffer_io
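Worked example of the rank bookkeeping above, for a hypothetical 4-rank group with writer_rank=0 where ranks 0 and 1 share the writer's node and ranks 2 and 3 sit on another node:

    status             = [True, True, False, False]  # in_the_same_node_as(pg, 0)
    same_node_ranks    = [0, 1]
    n_reader           = 3     # every rank except the writer
    n_local_reader     = 1     # only rank 1 reads via shared memory
    local_reader_ranks = [1]   # ranks 2 and 3 become remote zmq readers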
vllm/distributed/parallel_state.py
@@ -124,7 +124,7 @@ class GroupCoordinator:
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
-    shm_broadcaster: Optional[Any]  # shared memory broadcaster
+    mq_broadcaster: Optional[Any]  # shared memory broadcaster
 
     def __init__(
         self,
@@ -133,6 +133,7 @@ class GroupCoordinator:
         torch_distributed_backend: Union[str, Backend],
         use_pynccl: bool,
         use_custom_allreduce: bool,
+        use_message_queue_broadcaster: bool = False,
     ):
 
         self.rank = torch.distributed.get_rank()
@@ -190,10 +191,10 @@ class GroupCoordinator:
             self.ca_comm = None
 
         from vllm.distributed.device_communicators.shm_broadcast import (
-            ShmRingBufferIO)
-        self.shm_broadcaster: Optional[ShmRingBufferIO] = None
-        if self.world_size > 1 and is_in_the_same_node(self.cpu_group):
-            self.shm_broadcaster = ShmRingBufferIO.create_from_process_group(
+            MessageQueue)
+        self.mq_broadcaster: Optional[MessageQueue] = None
+        if use_message_queue_broadcaster and self.world_size > 1:
+            self.mq_broadcaster = MessageQueue.create_from_process_group(
                 self.cpu_group, 1 << 22, 6)
 
     @property
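For reference, the hard-coded arguments at this call site set the shared-memory budget (values straight from the hunk above):

    max_chunk_bytes = 1 << 22   # 4 MiB per chunk
    max_chunks = 6              # up to 6 in-flight messages, ~24 MiB total
    # before this commit, a broadcast whose pickle exceeded the chunk size
    # raised a RuntimeError; now it falls back to the zmq socket instead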
@@ -377,9 +378,9 @@ class GroupCoordinator:
         # Bypass the function if we are using only 1 GPU.
         if self.world_size == 1:
             return obj
-        if self.shm_broadcaster is not None:
-            assert src == 0, "Shared memory broadcaster only supports src=0"
-            return self.shm_broadcaster.broadcast_object(obj)
+        if self.mq_broadcaster is not None:
+            assert src == 0, "Message queue broadcaster only supports src=0"
+            return self.mq_broadcaster.broadcast_object(obj)
         if self.rank_in_group == src:
             torch.distributed.broadcast_object_list([obj],
                                                     src=self.ranks[src],
@@ -696,8 +697,8 @@ class GroupCoordinator:
             self.pynccl_comm = None
         if self.ca_comm is not None:
             self.ca_comm = None
-        if self.shm_broadcaster is not None:
-            self.shm_broadcaster = None
+        if self.mq_broadcaster is not None:
+            self.mq_broadcaster = None
 
 
 _WORLD: Optional[GroupCoordinator] = None
@@ -720,10 +721,12 @@ def init_world_group(ranks: List[int], local_rank: int,
 
 
 def init_model_parallel_group(
     group_ranks: List[List[int]],
     local_rank: int,
     backend: str,
-    use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator:
+    use_custom_allreduce: Optional[bool] = None,
+    use_message_queue_broadcaster: bool = False,
+) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     return GroupCoordinator(
@@ -732,6 +735,7 @@ def init_model_parallel_group(
         torch_distributed_backend=backend,
         use_pynccl=True,
         use_custom_allreduce=use_custom_allreduce,
+        use_message_queue_broadcaster=use_message_queue_broadcaster,
     )
 
 
@@ -880,8 +884,12 @@ def initialize_model_parallel(
             range(i * tensor_model_parallel_size,
                   (i + 1) * tensor_model_parallel_size))
         group_ranks.append(ranks)
+
+    # message queue broadcaster is only used in tensor model parallel group
     _TP = init_model_parallel_group(group_ranks,
-                                    get_world_group().local_rank, backend)
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_message_queue_broadcaster=True)
 
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = (world_size //
@@ -993,15 +1001,15 @@ def destroy_distributed_environment():
         torch.distributed.destroy_process_group()
 
 
-def is_in_the_same_node(pg: ProcessGroup):
+def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
     """
-    This is a collective operation that checks if all processes in the group
-    are in the same node. It tests if all processes are attached to the same
+    This is a collective operation that returns if each rank is in the same node
+    as the source rank. It tests if processes are attached to the same
     memory system (shared access to shared memory).
     """
     assert torch.distributed.get_backend(
         pg) != torch.distributed.Backend.NCCL, (
-            "is_in_the_same_node should be tested with a non-NCCL group.")
+            "in_the_same_node_as should be tested with a non-NCCL group.")
     # local rank inside the group
     rank = torch.distributed.get_rank(group=pg)
     world_size = torch.distributed.get_world_size(group=pg)
@@ -1017,19 +1025,19 @@ def is_in_the_same_node(pg: ProcessGroup):
 
     try:
         with contextlib.suppress(OSError):
-            if rank == 0:
+            if rank == source_rank:
                 # create a shared memory segment
                 shm = shared_memory.SharedMemory(create=True, size=128)
                 shm.buf[:len(magic_message)] = magic_message
                 torch.distributed.broadcast_object_list([shm.name],
-                                                        src=ranks[0],
+                                                        src=ranks[source_rank],
                                                         group=pg)
-                is_in_the_same_node[0] = 1
+                is_in_the_same_node[rank] = 1
             else:
                 # try to open the shared memory segment
                 recv = [None]
                 torch.distributed.broadcast_object_list(recv,
-                                                        src=ranks[0],
+                                                        src=ranks[source_rank],
                                                         group=pg)
                 name = recv[0]
                 # fix to https://stackoverflow.com/q/62748654/9191338
@@ -1050,8 +1058,8 @@ def is_in_the_same_node(pg: ProcessGroup):
 
     # clean up the shared memory segment
     with contextlib.suppress(OSError):
-        if rank == 0 and shm:
+        if rank == source_rank and shm:
             shm.unlink()
     torch.distributed.all_reduce(is_in_the_same_node, group=pg)
 
-    return is_in_the_same_node.sum().item() == world_size
+    return [x == 1 for x in is_in_the_same_node.tolist()]
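The return-type change (bool -> List[bool]) is what lets MessageQueue.create_from_process_group split readers into local and remote sets. Semantics, for a hypothetical group where ranks 0-1 share a host and ranks 2-3 share another:

    status = in_the_same_node_as(pg, source_rank=0)
    # -> [True, True, False, False]
    all(status)                              # the old is_in_the_same_node() answer
    [i for i, s in enumerate(status) if s]   # ranks co-located with rank 0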