[core][distributed] fix zmq hang (#6759)
This commit is contained in:
parent
d88c458f44
commit
740374d456
@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Optional
|
||||
from typing import Mapping, MutableMapping, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
@ -40,7 +40,7 @@ class HTTPConnection:
|
||||
raise ValueError("Invalid HTTP URL: A valid HTTP URL "
|
||||
"must have scheme 'http' or 'https'.")
|
||||
|
||||
def _headers(self, **extras: str) -> Mapping[str, str]:
|
||||
def _headers(self, **extras: str) -> MutableMapping[str, str]:
|
||||
return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}
|
||||
|
||||
def get_response(
|
||||
|
||||
@ -9,7 +9,7 @@ from unittest.mock import patch
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore
|
||||
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
@ -153,9 +153,7 @@ class Handle:
|
||||
|
||||
buffer: Optional[ShmRingBuffer] = None
|
||||
local_subscribe_port: Optional[int] = None
|
||||
local_sync_port: Optional[int] = None
|
||||
remote_subscribe_port: Optional[int] = None
|
||||
remote_sync_port: Optional[int] = None
|
||||
|
||||
|
||||
class MessageQueue:
|
||||
@ -189,38 +187,36 @@ class MessageQueue:
|
||||
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
|
||||
max_chunks)
|
||||
|
||||
self.local_socket = context.socket(PUB)
|
||||
# XPUB is very similar to PUB,
|
||||
# except that it can receive subscription messages
|
||||
# to confirm the number of subscribers
|
||||
self.local_socket = context.socket(XPUB)
|
||||
# set the verbose option so that we can receive every subscription
|
||||
# message. otherwise, we will only receive the first subscription
|
||||
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
|
||||
self.local_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
local_subscribe_port = get_open_port()
|
||||
self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
|
||||
|
||||
self.local_sync_socket = context.socket(REP)
|
||||
local_sync_port = get_open_port()
|
||||
self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
|
||||
self.current_idx = 0
|
||||
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
local_subscribe_port = None
|
||||
local_sync_port = None
|
||||
self.local_socket = None
|
||||
self.local_sync_socket = None
|
||||
self.current_idx = -1
|
||||
|
||||
if n_remote_reader > 0:
|
||||
# for remote readers, we will:
|
||||
# create a publish-subscribe socket to communicate large data
|
||||
self.remote_socket = context.socket(PUB)
|
||||
self.remote_socket = context.socket(XPUB)
|
||||
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
remote_subscribe_port = get_open_port()
|
||||
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
|
||||
|
||||
self.remote_sync_socket = context.socket(REP)
|
||||
remote_sync_port = get_open_port()
|
||||
self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
|
||||
else:
|
||||
remote_subscribe_port = None
|
||||
remote_sync_port = None
|
||||
self.remote_socket = None
|
||||
self.remote_sync_socket = None
|
||||
|
||||
self._is_writer = True
|
||||
self._is_local_reader = False
|
||||
@ -233,9 +229,7 @@ class MessageQueue:
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
buffer=self.buffer,
|
||||
local_subscribe_port=local_subscribe_port,
|
||||
local_sync_port=local_sync_port,
|
||||
remote_subscribe_port=remote_subscribe_port,
|
||||
remote_sync_port=remote_sync_port,
|
||||
)
|
||||
|
||||
logger.info("vLLM message queue communication handle: %s", self.handle)
|
||||
@ -264,12 +258,7 @@ class MessageQueue:
|
||||
self.local_socket.connect(
|
||||
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
|
||||
|
||||
self.local_sync_socket = context.socket(REQ)
|
||||
self.local_sync_socket.connect(
|
||||
f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
|
||||
|
||||
self.remote_socket = None
|
||||
self.remote_sync_socket = None
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
self.current_idx = -1
|
||||
@ -278,17 +267,12 @@ class MessageQueue:
|
||||
self._is_remote_reader = True
|
||||
|
||||
self.local_socket = None
|
||||
self.local_sync_socket = None
|
||||
|
||||
self.remote_socket = context.socket(SUB)
|
||||
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
self.remote_socket.connect(
|
||||
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
|
||||
|
||||
self.remote_sync_socket = context.socket(REQ)
|
||||
self.remote_sync_socket.connect(
|
||||
f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
|
||||
|
||||
return self
|
||||
|
||||
def wait_until_ready(self):
|
||||
@ -300,29 +284,27 @@ class MessageQueue:
|
||||
|
||||
# local readers
|
||||
for i in range(self.n_local_reader):
|
||||
recv = self.local_sync_socket.recv()
|
||||
assert recv == b"READY"
|
||||
self.local_sync_socket.send(b"READY")
|
||||
# wait for subscription messages from all local readers
|
||||
self.local_socket.recv()
|
||||
if self.n_local_reader > 0:
|
||||
# send a message to all local readers
|
||||
# to make sure the publish channel is working
|
||||
self.local_socket.send(b"READY")
|
||||
|
||||
# remote readers
|
||||
for i in range(self.n_remote_reader):
|
||||
recv = self.remote_sync_socket.recv()
|
||||
assert recv == b"READY"
|
||||
self.remote_sync_socket.send(b"READY")
|
||||
# wait for subscription messages from all remote readers
|
||||
self.remote_socket.recv()
|
||||
if self.n_remote_reader > 0:
|
||||
# send a message to all remote readers
|
||||
# to make sure the publish channel is working
|
||||
self.remote_socket.send(b"READY")
|
||||
elif self._is_local_reader:
|
||||
self.local_sync_socket.send(b"READY")
|
||||
recv = self.local_sync_socket.recv()
|
||||
assert recv == b"READY"
|
||||
# wait for the writer to send a message
|
||||
recv = self.local_socket.recv()
|
||||
assert recv == b"READY"
|
||||
elif self._is_remote_reader:
|
||||
self.remote_sync_socket.send(b"READY")
|
||||
recv = self.remote_sync_socket.recv()
|
||||
assert recv == b"READY"
|
||||
# wait for the writer to send a message
|
||||
recv = self.remote_socket.recv()
|
||||
assert recv == b"READY"
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user