[core][distributed] fix zmq hang (#6759)

This commit is contained in:
youkaichao 2024-07-24 17:37:12 -07:00 committed by GitHub
parent d88c458f44
commit 740374d456
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 41 deletions

View File

@@ -1,5 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Mapping, Optional from typing import Mapping, MutableMapping, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp import aiohttp
@@ -40,7 +40,7 @@ class HTTPConnection:
raise ValueError("Invalid HTTP URL: A valid HTTP URL " raise ValueError("Invalid HTTP URL: A valid HTTP URL "
"must have scheme 'http' or 'https'.") "must have scheme 'http' or 'https'.")
def _headers(self, **extras: str) -> Mapping[str, str]: def _headers(self, **extras: str) -> MutableMapping[str, str]:
return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras} return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}
def get_response( def get_response(

View File

@@ -9,7 +9,7 @@ from unittest.mock import patch
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
@@ -153,9 +153,7 @@ class Handle:
buffer: Optional[ShmRingBuffer] = None buffer: Optional[ShmRingBuffer] = None
local_subscribe_port: Optional[int] = None local_subscribe_port: Optional[int] = None
local_sync_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None remote_subscribe_port: Optional[int] = None
remote_sync_port: Optional[int] = None
class MessageQueue: class MessageQueue:
@@ -189,38 +187,36 @@ class MessageQueue:
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
max_chunks) max_chunks)
self.local_socket = context.socket(PUB) # XPUB is very similar to PUB,
# except that it can receive subscription messages
# to confirm the number of subscribers
self.local_socket = context.socket(XPUB)
# set the verbose option so that we can receive every subscription
# message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self.local_socket.setsockopt(XPUB_VERBOSE, True)
local_subscribe_port = get_open_port() local_subscribe_port = get_open_port()
self.local_socket.bind(f"tcp://*:{local_subscribe_port}") self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
self.local_sync_socket = context.socket(REP)
local_sync_port = get_open_port()
self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
self.current_idx = 0 self.current_idx = 0
else: else:
self.buffer = None # type: ignore self.buffer = None # type: ignore
local_subscribe_port = None local_subscribe_port = None
local_sync_port = None
self.local_socket = None self.local_socket = None
self.local_sync_socket = None
self.current_idx = -1 self.current_idx = -1
if n_remote_reader > 0: if n_remote_reader > 0:
# for remote readers, we will: # for remote readers, we will:
# create a publish-subscribe socket to communicate large data # create a publish-subscribe socket to communicate large data
self.remote_socket = context.socket(PUB) self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port() remote_subscribe_port = get_open_port()
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
self.remote_sync_socket = context.socket(REP)
remote_sync_port = get_open_port()
self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
else: else:
remote_subscribe_port = None remote_subscribe_port = None
remote_sync_port = None
self.remote_socket = None self.remote_socket = None
self.remote_sync_socket = None
self._is_writer = True self._is_writer = True
self._is_local_reader = False self._is_local_reader = False
@@ -233,9 +229,7 @@ class MessageQueue:
local_reader_ranks=local_reader_ranks, local_reader_ranks=local_reader_ranks,
buffer=self.buffer, buffer=self.buffer,
local_subscribe_port=local_subscribe_port, local_subscribe_port=local_subscribe_port,
local_sync_port=local_sync_port,
remote_subscribe_port=remote_subscribe_port, remote_subscribe_port=remote_subscribe_port,
remote_sync_port=remote_sync_port,
) )
logger.info("vLLM message queue communication handle: %s", self.handle) logger.info("vLLM message queue communication handle: %s", self.handle)
@@ -264,12 +258,7 @@ class MessageQueue:
self.local_socket.connect( self.local_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}") f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
self.local_sync_socket = context.socket(REQ)
self.local_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
self.remote_socket = None self.remote_socket = None
self.remote_sync_socket = None
else: else:
self.buffer = None # type: ignore self.buffer = None # type: ignore
self.current_idx = -1 self.current_idx = -1
@@ -278,17 +267,12 @@ class MessageQueue:
self._is_remote_reader = True self._is_remote_reader = True
self.local_socket = None self.local_socket = None
self.local_sync_socket = None
self.remote_socket = context.socket(SUB) self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "") self.remote_socket.setsockopt_string(SUBSCRIBE, "")
self.remote_socket.connect( self.remote_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}") f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
self.remote_sync_socket = context.socket(REQ)
self.remote_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
return self return self
def wait_until_ready(self): def wait_until_ready(self):
@@ -300,29 +284,27 @@ class MessageQueue:
# local readers # local readers
for i in range(self.n_local_reader): for i in range(self.n_local_reader):
recv = self.local_sync_socket.recv() # wait for subscription messages from all local readers
assert recv == b"READY" self.local_socket.recv()
self.local_sync_socket.send(b"READY")
if self.n_local_reader > 0: if self.n_local_reader > 0:
# send a message to all local readers
# to make sure the publish channel is working
self.local_socket.send(b"READY") self.local_socket.send(b"READY")
# remote readers # remote readers
for i in range(self.n_remote_reader): for i in range(self.n_remote_reader):
recv = self.remote_sync_socket.recv() # wait for subscription messages from all remote readers
assert recv == b"READY" self.remote_socket.recv()
self.remote_sync_socket.send(b"READY")
if self.n_remote_reader > 0: if self.n_remote_reader > 0:
# send a message to all remote readers
# to make sure the publish channel is working
self.remote_socket.send(b"READY") self.remote_socket.send(b"READY")
elif self._is_local_reader: elif self._is_local_reader:
self.local_sync_socket.send(b"READY") # wait for the writer to send a message
recv = self.local_sync_socket.recv()
assert recv == b"READY"
recv = self.local_socket.recv() recv = self.local_socket.recv()
assert recv == b"READY" assert recv == b"READY"
elif self._is_remote_reader: elif self._is_remote_reader:
self.remote_sync_socket.send(b"READY") # wait for the writer to send a message
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
recv = self.remote_socket.recv() recv = self.remote_socket.recv()
assert recv == b"READY" assert recv == b"READY"