[core][distributed] fix zmq hang (#6759)

This commit is contained in:
youkaichao 2024-07-24 17:37:12 -07:00 committed by GitHub
parent d88c458f44
commit 740374d456
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 41 deletions

View File

@ -1,5 +1,5 @@
from pathlib import Path
from typing import Mapping, Optional
from typing import Mapping, MutableMapping, Optional
from urllib.parse import urlparse
import aiohttp
@ -40,7 +40,7 @@ class HTTPConnection:
raise ValueError("Invalid HTTP URL: A valid HTTP URL "
"must have scheme 'http' or 'https'.")
def _headers(self, **extras: str) -> Mapping[str, str]:
def _headers(self, **extras: str) -> MutableMapping[str, str]:
return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}
def get_response(

View File

@ -9,7 +9,7 @@ from unittest.mock import patch
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
import vllm.envs as envs
from vllm.logger import init_logger
@ -153,9 +153,7 @@ class Handle:
buffer: Optional[ShmRingBuffer] = None
local_subscribe_port: Optional[int] = None
local_sync_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None
remote_sync_port: Optional[int] = None
class MessageQueue:
@ -189,38 +187,36 @@ class MessageQueue:
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
max_chunks)
self.local_socket = context.socket(PUB)
# XPUB is very similar to PUB,
# except that it can receive subscription messages
# to confirm the number of subscribers
self.local_socket = context.socket(XPUB)
# set the verbose option so that we can receive every subscription
# message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self.local_socket.setsockopt(XPUB_VERBOSE, True)
local_subscribe_port = get_open_port()
self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
self.local_sync_socket = context.socket(REP)
local_sync_port = get_open_port()
self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
self.current_idx = 0
else:
self.buffer = None # type: ignore
local_subscribe_port = None
local_sync_port = None
self.local_socket = None
self.local_sync_socket = None
self.current_idx = -1
if n_remote_reader > 0:
# for remote readers, we will:
# create a publish-subscribe socket to communicate large data
self.remote_socket = context.socket(PUB)
self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port()
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
self.remote_sync_socket = context.socket(REP)
remote_sync_port = get_open_port()
self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
else:
remote_subscribe_port = None
remote_sync_port = None
self.remote_socket = None
self.remote_sync_socket = None
self._is_writer = True
self._is_local_reader = False
@ -233,9 +229,7 @@ class MessageQueue:
local_reader_ranks=local_reader_ranks,
buffer=self.buffer,
local_subscribe_port=local_subscribe_port,
local_sync_port=local_sync_port,
remote_subscribe_port=remote_subscribe_port,
remote_sync_port=remote_sync_port,
)
logger.info("vLLM message queue communication handle: %s", self.handle)
@ -264,12 +258,7 @@ class MessageQueue:
self.local_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
self.local_sync_socket = context.socket(REQ)
self.local_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
self.remote_socket = None
self.remote_sync_socket = None
else:
self.buffer = None # type: ignore
self.current_idx = -1
@ -278,17 +267,12 @@ class MessageQueue:
self._is_remote_reader = True
self.local_socket = None
self.local_sync_socket = None
self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
self.remote_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
self.remote_sync_socket = context.socket(REQ)
self.remote_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
return self
def wait_until_ready(self):
@ -300,29 +284,27 @@ class MessageQueue:
# local readers
for i in range(self.n_local_reader):
recv = self.local_sync_socket.recv()
assert recv == b"READY"
self.local_sync_socket.send(b"READY")
# wait for subscription messages from all local readers
self.local_socket.recv()
if self.n_local_reader > 0:
# send a message to all local readers
# to make sure the publish channel is working
self.local_socket.send(b"READY")
# remote readers
for i in range(self.n_remote_reader):
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
self.remote_sync_socket.send(b"READY")
# wait for subscription messages from all remote readers
self.remote_socket.recv()
if self.n_remote_reader > 0:
# send a message to all remote readers
# to make sure the publish channel is working
self.remote_socket.send(b"READY")
elif self._is_local_reader:
self.local_sync_socket.send(b"READY")
recv = self.local_sync_socket.recv()
assert recv == b"READY"
# wait for the writer to send a message
recv = self.local_socket.recv()
assert recv == b"READY"
elif self._is_remote_reader:
self.remote_sync_socket.send(b"READY")
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
# wait for the writer to send a message
recv = self.remote_socket.recv()
assert recv == b"READY"