diff --git a/requirements-common.txt b/requirements-common.txt
index 765568b0..e874c4af 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -21,3 +21,4 @@ lm-format-enforcer == 0.10.1
 outlines >= 0.0.43 # Requires torch >= 2.1.0
 typing_extensions
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
+pyzmq
diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py
index 4880bab7..2d886eb5 100644
--- a/tests/distributed/test_same_node.py
+++ b/tests/distributed/test_same_node.py
@@ -2,10 +2,11 @@ import os
 
 import torch
 
-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as
 
 torch.distributed.init_process_group(backend="gloo")
-test_result = is_in_the_same_node(torch.distributed.group.WORLD)
+test_result = all(
+    in_the_same_node_as(torch.distributed.group.WORLD, source_rank=0))
 
 expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
 assert test_result == expected, f"Expected {expected}, got {test_result}"
diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py
index 2c2466f8..2761b7f6 100644
--- a/tests/distributed/test_shm_broadcast.py
+++ b/tests/distributed/test_shm_broadcast.py
@@ -6,8 +6,7 @@ from typing import List
 import numpy as np
 import torch.distributed as dist
 
-from vllm.distributed.device_communicators.shm_broadcast import (
-    ShmRingBuffer, ShmRingBufferIO)
+from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
 from vllm.utils import update_environment_variables
 
 
@@ -56,8 +55,8 @@ def worker_fn_wrapper(fn):
 @worker_fn_wrapper
 def worker_fn():
     writer_rank = 2
-    broadcaster = ShmRingBufferIO.create_from_process_group(
-        dist.group.WORLD, 1024 * 1024, 2, writer_rank)
+    broadcaster = MessageQueue.create_from_process_group(
+        dist.group.WORLD, 40 * 1024, 2, writer_rank)
     if dist.get_rank() == writer_rank:
         seed = random.randint(0, 1000)
         dist.broadcast_object_list([seed], writer_rank)
@@ -87,13 +86,3 @@ def worker_fn():
 
 def test_shm_broadcast():
     distributed_run(worker_fn, 4)
-
-
-def test_singe_process():
-    buffer = ShmRingBuffer(1, 1024, 4)
-    reader = ShmRingBufferIO(buffer, reader_rank=0)
-    writer = ShmRingBufferIO(buffer, reader_rank=-1)
-    writer.enqueue([0])
-    writer.enqueue([1])
-    assert reader.dequeue() == [0]
-    assert reader.dequeue() == [1]
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index a303d0bd..a4f30808 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -9,7 +9,7 @@ import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless, is_full_nvlink
 
@@ -64,7 +64,7 @@ class CustomAllreduce:
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")
 
-        if not is_in_the_same_node(group):
+        if not all(in_the_same_node_as(group, source_rank=0)):
             # No need to initialize custom allreduce for multi-node case.
             logger.warning(
                 "Custom allreduce is disabled because this process group"
diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py
index bea20588..db006495 100644
--- a/vllm/distributed/device_communicators/shm_broadcast.py
+++ b/vllm/distributed/device_communicators/shm_broadcast.py
@@ -1,16 +1,19 @@
 import pickle
 import time
 from contextlib import contextmanager
+from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import Optional
+from typing import List, Optional
 from unittest.mock import patch
 
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context  # type: ignore
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import get_ip, get_open_port
 
 VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
 
@@ -135,18 +138,183 @@ class ShmRingBuffer:
             yield buf
 
 
-class ShmRingBufferIO:
+@dataclass
+class Handle:
+    connect_ip: str
+    local_reader_ranks: List[int] = field(default_factory=list)
 
-    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
-        self.buffer = buffer
-        self.reader_rank = reader_rank
-        self._is_writer = self.reader_rank == -1
-        self._is_reader = not self._is_writer
-        if self._is_reader:
-            assert 0 <= self.reader_rank < buffer.n_reader, \
-                (f"Invalid reader rank {self.reader_rank} for buffer"
-                 f" created with {buffer.n_reader} readers")
-        self.current_idx = 0
+    buffer: Optional[ShmRingBuffer] = None
+    local_subscribe_port: Optional[int] = None
+    local_sync_port: Optional[int] = None
+    remote_subscribe_port: Optional[int] = None
+    remote_sync_port: Optional[int] = None
+
+
+class MessageQueue:
+
+    def __init__(
+        self,
+        n_reader,  # number of all readers
+        n_local_reader,  # number of local readers through shared memory
+        local_reader_ranks: Optional[List[int]] = None,
+        max_chunk_bytes: int = 1024 * 1024 * 10,
+        max_chunks: int = 10,
+        connect_ip: Optional[str] = None,
+    ):
+        if local_reader_ranks is None:
+            local_reader_ranks = list(range(n_local_reader))
+        else:
+            assert len(local_reader_ranks) == n_local_reader
+        self.n_local_reader = n_local_reader
+        n_remote_reader = n_reader - n_local_reader
+        self.n_remote_reader = n_remote_reader
+
+        if connect_ip is None:
+            connect_ip = get_ip()
+
+        context = Context()
+
+        if n_local_reader > 0:
+            # for local readers, we will:
+            # 1. create a shared memory ring buffer to communicate small data
+            # 2. create a publish-subscribe socket to communicate large data
+            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
+                                        max_chunks)
+
+            self.local_socket = context.socket(PUB)
+            local_subscribe_port = get_open_port()
+            self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
+
+            self.local_sync_socket = context.socket(REP)
+            local_sync_port = get_open_port()
+            self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
+            self.current_idx = 0
+
+        else:
+            self.buffer = None  # type: ignore
+            local_subscribe_port = None
+            local_sync_port = None
+            self.local_socket = None
+            self.local_sync_socket = None
+            self.current_idx = -1
+
+        if n_remote_reader > 0:
+            # for remote readers, we will:
+            # create a publish-subscribe socket to communicate large data
+            self.remote_socket = context.socket(PUB)
+            remote_subscribe_port = get_open_port()
+            self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
+
+            self.remote_sync_socket = context.socket(REP)
+            remote_sync_port = get_open_port()
+            self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
+        else:
+            remote_subscribe_port = None
+            remote_sync_port = None
+            self.remote_socket = None
+            self.remote_sync_socket = None
+
+        self._is_writer = True
+        self._is_local_reader = False
+        self.local_reader_rank = -1
+        # rank does not matter for remote readers
+        self._is_remote_reader = False
+
+        self.handle = Handle(
+            connect_ip=connect_ip,
+            local_reader_ranks=local_reader_ranks,
+            buffer=self.buffer,
+            local_subscribe_port=local_subscribe_port,
+            local_sync_port=local_sync_port,
+            remote_subscribe_port=remote_subscribe_port,
+            remote_sync_port=remote_sync_port,
+        )
+
+    def export_handle(self) -> Handle:
+        return self.handle
+
+    @staticmethod
+    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
+        self = MessageQueue.__new__(MessageQueue)
+        self.handle = handle
+        self._is_writer = False
+
+        context = Context()
+
+        if rank in handle.local_reader_ranks:
+            assert handle.buffer is not None
+            self.buffer = handle.buffer
+            self.current_idx = 0
+            self.local_reader_rank = handle.local_reader_ranks.index(rank)
+            self._is_local_reader = True
+            self._is_remote_reader = False
+
+            self.local_socket = context.socket(SUB)
+            self.local_socket.setsockopt_string(SUBSCRIBE, "")
+            self.local_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
+
+            self.local_sync_socket = context.socket(REQ)
+            self.local_sync_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
+
+            self.remote_socket = None
+            self.remote_sync_socket = None
+        else:
+            self.buffer = None  # type: ignore
+            self.current_idx = -1
+            self.local_reader_rank = -1
+            self._is_local_reader = False
+            self._is_remote_reader = True
+
+            self.local_socket = None
+            self.local_sync_socket = None
+
+            self.remote_socket = context.socket(SUB)
+            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
+            self.remote_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
+
+            self.remote_sync_socket = context.socket(REQ)
+            self.remote_sync_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
+
+        return self
+
+    def wait_until_ready(self):
+        """This is a collective operation. All processes (including the
+        readers and the writer) should call this function.
+        """
+        if self._is_writer:
+            # wait for all readers to connect
+
+            # local readers
+            for i in range(self.n_local_reader):
+                recv = self.local_sync_socket.recv()
+                assert recv == b"READY"
+                self.local_sync_socket.send(b"READY")
+            if self.n_local_reader > 0:
+                self.local_socket.send(b"READY")
+
+            # remote readers
+            for i in range(self.n_remote_reader):
+                recv = self.remote_sync_socket.recv()
+                assert recv == b"READY"
+                self.remote_sync_socket.send(b"READY")
+            if self.n_remote_reader > 0:
+                self.remote_socket.send(b"READY")
+        elif self._is_local_reader:
+            self.local_sync_socket.send(b"READY")
+            recv = self.local_sync_socket.recv()
+            assert recv == b"READY"
+            recv = self.local_socket.recv()
+            assert recv == b"READY"
+        elif self._is_remote_reader:
+            self.remote_sync_socket.send(b"READY")
+            recv = self.remote_sync_socket.recv()
+            assert recv == b"READY"
+            recv = self.remote_socket.recv()
+            assert recv == b"READY"
 
     @contextmanager
     def acquire_write(self):
@@ -201,12 +369,12 @@ class ShmRingBufferIO:
 
     @contextmanager
     def acquire_read(self):
-        assert self._is_reader, "Only readers can acquire read"
+        assert self._is_local_reader, "Only readers can acquire read"
         start_time = time.monotonic()
         n_warning = 1
         while True:
             with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
-                read_flag = metadata_buffer[self.reader_rank + 1]
+                read_flag = metadata_buffer[self.local_reader_rank + 1]
                 written_flag = metadata_buffer[0]
                 if not written_flag or read_flag:
                     # this block is either
@@ -236,7 +404,7 @@ class ShmRingBufferIO:
 
                 # caller has read from the buffer
                 # set the read flag
-                metadata_buffer[self.reader_rank + 1] = 1
+                metadata_buffer[self.local_reader_rank + 1] = 1
                 self.current_idx = (self.current_idx +
                                     1) % self.buffer.max_chunks
                 break
@@ -244,21 +412,36 @@ class ShmRingBufferIO:
     def enqueue(self, obj):
         assert self._is_writer, "Only writers can enqueue"
         serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
-        if len(serialized_obj) > self.buffer.max_chunk_bytes:
-            raise RuntimeError(
-                f"{len(serialized_obj)=} larger than the allowed value "
-                f"{self.buffer.max_chunk_bytes},"
-                "Please increase the max_chunk_bytes parameter.")
-        with self.acquire_write() as buf:
-            buf[:len(serialized_obj)] = serialized_obj
+        if self.n_local_reader > 0:
+            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
+                with self.acquire_write() as buf:
+                    buf[0] = 1  # overflow
+                self.local_socket.send(serialized_obj)
+            else:
+                with self.acquire_write() as buf:
+                    buf[0] = 0  # not overflow
+                    buf[1:len(serialized_obj) + 1] = serialized_obj
+        if self.n_remote_reader > 0:
+            self.remote_socket.send(serialized_obj)
 
     def dequeue(self):
-        assert self._is_reader, "Only readers can dequeue"
-        with self.acquire_read() as buf:
-            # no need to know the size of serialized object
-            # pickle format itself contains the size information internally
-            # see https://docs.python.org/3/library/pickle.html
-            obj = pickle.loads(buf)
+        if self._is_local_reader:
+            overflow = False
+            with self.acquire_read() as buf:
+                overflow = buf[0] == 1
+                if not overflow:
+                    # no need to know the size of serialized object
+                    # pickle format contains the size information internally
+                    # see https://docs.python.org/3/library/pickle.html
+                    obj = pickle.loads(buf[1:])
+            if overflow:
+                recv = self.local_socket.recv()
+                obj = pickle.loads(recv)
+        elif self._is_remote_reader:
+            recv = self.remote_socket.recv()
+            obj = pickle.loads(recv)
+        else:
+            raise RuntimeError("Only readers can dequeue")
         return obj
 
     def broadcast_object(self, obj=None):
@@ -272,24 +455,36 @@ class ShmRingBufferIO:
     def create_from_process_group(pg: ProcessGroup,
                                   max_chunk_bytes,
                                   max_chunks,
-                                  writer_rank=0) -> "ShmRingBufferIO":
+                                  writer_rank=0) -> "MessageQueue":
         group_rank = dist.get_rank(pg)
         group_world_size = dist.get_world_size(pg)
-        ranks_inside_group = list(range(group_world_size))
         global_ranks = dist.get_process_group_ranks(pg)
+
+        from vllm.distributed.parallel_state import in_the_same_node_as
+        status = in_the_same_node_as(pg, source_rank=writer_rank)
+        same_node_ranks = [i for i, s in enumerate(status) if s]
         n_reader = group_world_size - 1
-        buffer: ShmRingBuffer
+        n_local_reader = len(same_node_ranks) - 1
+        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
+        buffer_io: MessageQueue
         if group_rank == writer_rank:
-            buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
-            dist.broadcast_object_list([buffer],
+            buffer_io = MessageQueue(
+                n_reader=n_reader,
+                n_local_reader=n_local_reader,
+                local_reader_ranks=local_reader_ranks,
+                max_chunk_bytes=max_chunk_bytes,
+                max_chunks=max_chunks,
+            )
+            handle = buffer_io.export_handle()
+            dist.broadcast_object_list([handle],
                                        src=global_ranks[writer_rank],
                                        group=pg)
-            return ShmRingBufferIO(buffer, -1)
         else:
             recv = [None]
             dist.broadcast_object_list(recv,
                                        src=global_ranks[writer_rank],
                                        group=pg)
-            buffer = recv[0]  # type: ignore
-            rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
-            return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
+            handle = recv[0]  # type: ignore
+            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
+        buffer_io.wait_until_ready()
+        return buffer_io
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 66ffe6e8..128096c8 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -124,7 +124,7 @@ class GroupCoordinator:
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
-    shm_broadcaster: Optional[Any]  # shared memory broadcaster
+    mq_broadcaster: Optional[Any]  # shared memory broadcaster
 
     def __init__(
         self,
@@ -133,6 +133,7 @@ class GroupCoordinator:
         torch_distributed_backend: Union[str, Backend],
         use_pynccl: bool,
         use_custom_allreduce: bool,
+        use_message_queue_broadcaster: bool = False,
     ):
         self.rank = torch.distributed.get_rank()
 
@@ -190,10 +191,10 @@ class GroupCoordinator:
             self.ca_comm = None
 
         from vllm.distributed.device_communicators.shm_broadcast import (
-            ShmRingBufferIO)
-        self.shm_broadcaster: Optional[ShmRingBufferIO] = None
-        if self.world_size > 1 and is_in_the_same_node(self.cpu_group):
-            self.shm_broadcaster = ShmRingBufferIO.create_from_process_group(
+            MessageQueue)
+        self.mq_broadcaster: Optional[MessageQueue] = None
+        if use_message_queue_broadcaster and self.world_size > 1:
+            self.mq_broadcaster = MessageQueue.create_from_process_group(
                 self.cpu_group, 1 << 22, 6)
 
     @property
@@ -377,9 +378,9 @@ class GroupCoordinator:
         # Bypass the function if we are using only 1 GPU.
         if self.world_size == 1:
             return obj
-        if self.shm_broadcaster is not None:
-            assert src == 0, "Shared memory broadcaster only supports src=0"
-            return self.shm_broadcaster.broadcast_object(obj)
+        if self.mq_broadcaster is not None:
+            assert src == 0, "Message queue broadcaster only supports src=0"
+            return self.mq_broadcaster.broadcast_object(obj)
         if self.rank_in_group == src:
             torch.distributed.broadcast_object_list([obj],
                                                     src=self.ranks[src],
@@ -696,8 +697,8 @@ class GroupCoordinator:
             self.pynccl_comm = None
         if self.ca_comm is not None:
             self.ca_comm = None
-        if self.shm_broadcaster is not None:
-            self.shm_broadcaster = None
+        if self.mq_broadcaster is not None:
+            self.mq_broadcaster = None
 
 
 _WORLD: Optional[GroupCoordinator] = None
@@ -720,10 +721,12 @@ def init_world_group(ranks: List[int], local_rank: int,
 
 
 def init_model_parallel_group(
-        group_ranks: List[List[int]],
-        local_rank: int,
-        backend: str,
-        use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator:
+    group_ranks: List[List[int]],
+    local_rank: int,
+    backend: str,
+    use_custom_allreduce: Optional[bool] = None,
+    use_message_queue_broadcaster: bool = False,
+) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     return GroupCoordinator(
@@ -732,6 +735,7 @@ def init_model_parallel_group(
         torch_distributed_backend=backend,
         use_pynccl=True,
         use_custom_allreduce=use_custom_allreduce,
+        use_message_queue_broadcaster=use_message_queue_broadcaster,
     )
 
 
@@ -880,8 +884,12 @@ def initialize_model_parallel(
             range(i * tensor_model_parallel_size,
                   (i + 1) * tensor_model_parallel_size))
         group_ranks.append(ranks)
+
+    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
-                                    get_world_group().local_rank, backend)
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_message_queue_broadcaster=True)
 
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = (world_size //
@@ -993,15 +1001,15 @@ def destroy_distributed_environment():
         torch.distributed.destroy_process_group()
 
 
-def is_in_the_same_node(pg: ProcessGroup):
+def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
     """
-    This is a collective operation that checks if all processes in the group
-    are in the same node. It tests if all processes are attached to the same
+    This is a collective operation that returns if each rank is in the same node
+    as the source rank. It tests if processes are attached to the same
     memory system (shared access to shared memory).
""" assert torch.distributed.get_backend( pg) != torch.distributed.Backend.NCCL, ( - "is_in_the_same_node should be tested with a non-NCCL group.") + "in_the_same_node_as should be tested with a non-NCCL group.") # local rank inside the group rank = torch.distributed.get_rank(group=pg) world_size = torch.distributed.get_world_size(group=pg) @@ -1017,19 +1025,19 @@ def is_in_the_same_node(pg: ProcessGroup): try: with contextlib.suppress(OSError): - if rank == 0: + if rank == source_rank: # create a shared memory segment shm = shared_memory.SharedMemory(create=True, size=128) shm.buf[:len(magic_message)] = magic_message torch.distributed.broadcast_object_list([shm.name], - src=ranks[0], + src=ranks[source_rank], group=pg) - is_in_the_same_node[0] = 1 + is_in_the_same_node[rank] = 1 else: # try to open the shared memory segment recv = [None] torch.distributed.broadcast_object_list(recv, - src=ranks[0], + src=ranks[source_rank], group=pg) name = recv[0] # fix to https://stackoverflow.com/q/62748654/9191338 @@ -1050,8 +1058,8 @@ def is_in_the_same_node(pg: ProcessGroup): # clean up the shared memory segment with contextlib.suppress(OSError): - if rank == 0 and shm: + if rank == source_rank and shm: shm.unlink() torch.distributed.all_reduce(is_in_the_same_node, group=pg) - return is_in_the_same_node.sum().item() == world_size + return [x == 1 for x in is_in_the_same_node.tolist()]