[core][distributed] fix zmq hang (#6759)
This commit is contained in:
parent
d88c458f44
commit
740374d456
@ -1,5 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Mapping, Optional
|
from typing import Mapping, MutableMapping, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -40,7 +40,7 @@ class HTTPConnection:
|
|||||||
raise ValueError("Invalid HTTP URL: A valid HTTP URL "
|
raise ValueError("Invalid HTTP URL: A valid HTTP URL "
|
||||||
"must have scheme 'http' or 'https'.")
|
"must have scheme 'http' or 'https'.")
|
||||||
|
|
||||||
def _headers(self, **extras: str) -> Mapping[str, str]:
|
def _headers(self, **extras: str) -> MutableMapping[str, str]:
|
||||||
return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}
|
return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}
|
||||||
|
|
||||||
def get_response(
|
def get_response(
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from unittest.mock import patch
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore
|
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -153,9 +153,7 @@ class Handle:
|
|||||||
|
|
||||||
buffer: Optional[ShmRingBuffer] = None
|
buffer: Optional[ShmRingBuffer] = None
|
||||||
local_subscribe_port: Optional[int] = None
|
local_subscribe_port: Optional[int] = None
|
||||||
local_sync_port: Optional[int] = None
|
|
||||||
remote_subscribe_port: Optional[int] = None
|
remote_subscribe_port: Optional[int] = None
|
||||||
remote_sync_port: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class MessageQueue:
|
class MessageQueue:
|
||||||
@ -189,38 +187,36 @@ class MessageQueue:
|
|||||||
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
|
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
|
||||||
max_chunks)
|
max_chunks)
|
||||||
|
|
||||||
self.local_socket = context.socket(PUB)
|
# XPUB is very similar to PUB,
|
||||||
|
# except that it can receive subscription messages
|
||||||
|
# to confirm the number of subscribers
|
||||||
|
self.local_socket = context.socket(XPUB)
|
||||||
|
# set the verbose option so that we can receive every subscription
|
||||||
|
# message. otherwise, we will only receive the first subscription
|
||||||
|
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
|
||||||
|
self.local_socket.setsockopt(XPUB_VERBOSE, True)
|
||||||
local_subscribe_port = get_open_port()
|
local_subscribe_port = get_open_port()
|
||||||
self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
|
self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
|
||||||
|
|
||||||
self.local_sync_socket = context.socket(REP)
|
|
||||||
local_sync_port = get_open_port()
|
|
||||||
self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
|
|
||||||
self.current_idx = 0
|
self.current_idx = 0
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.buffer = None # type: ignore
|
self.buffer = None # type: ignore
|
||||||
local_subscribe_port = None
|
local_subscribe_port = None
|
||||||
local_sync_port = None
|
|
||||||
self.local_socket = None
|
self.local_socket = None
|
||||||
self.local_sync_socket = None
|
|
||||||
self.current_idx = -1
|
self.current_idx = -1
|
||||||
|
|
||||||
if n_remote_reader > 0:
|
if n_remote_reader > 0:
|
||||||
# for remote readers, we will:
|
# for remote readers, we will:
|
||||||
# create a publish-subscribe socket to communicate large data
|
# create a publish-subscribe socket to communicate large data
|
||||||
self.remote_socket = context.socket(PUB)
|
self.remote_socket = context.socket(XPUB)
|
||||||
|
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
|
||||||
remote_subscribe_port = get_open_port()
|
remote_subscribe_port = get_open_port()
|
||||||
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
|
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
|
||||||
|
|
||||||
self.remote_sync_socket = context.socket(REP)
|
|
||||||
remote_sync_port = get_open_port()
|
|
||||||
self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
|
|
||||||
else:
|
else:
|
||||||
remote_subscribe_port = None
|
remote_subscribe_port = None
|
||||||
remote_sync_port = None
|
|
||||||
self.remote_socket = None
|
self.remote_socket = None
|
||||||
self.remote_sync_socket = None
|
|
||||||
|
|
||||||
self._is_writer = True
|
self._is_writer = True
|
||||||
self._is_local_reader = False
|
self._is_local_reader = False
|
||||||
@ -233,9 +229,7 @@ class MessageQueue:
|
|||||||
local_reader_ranks=local_reader_ranks,
|
local_reader_ranks=local_reader_ranks,
|
||||||
buffer=self.buffer,
|
buffer=self.buffer,
|
||||||
local_subscribe_port=local_subscribe_port,
|
local_subscribe_port=local_subscribe_port,
|
||||||
local_sync_port=local_sync_port,
|
|
||||||
remote_subscribe_port=remote_subscribe_port,
|
remote_subscribe_port=remote_subscribe_port,
|
||||||
remote_sync_port=remote_sync_port,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("vLLM message queue communication handle: %s", self.handle)
|
logger.info("vLLM message queue communication handle: %s", self.handle)
|
||||||
@ -264,12 +258,7 @@ class MessageQueue:
|
|||||||
self.local_socket.connect(
|
self.local_socket.connect(
|
||||||
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
|
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
|
||||||
|
|
||||||
self.local_sync_socket = context.socket(REQ)
|
|
||||||
self.local_sync_socket.connect(
|
|
||||||
f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
|
|
||||||
|
|
||||||
self.remote_socket = None
|
self.remote_socket = None
|
||||||
self.remote_sync_socket = None
|
|
||||||
else:
|
else:
|
||||||
self.buffer = None # type: ignore
|
self.buffer = None # type: ignore
|
||||||
self.current_idx = -1
|
self.current_idx = -1
|
||||||
@ -278,17 +267,12 @@ class MessageQueue:
|
|||||||
self._is_remote_reader = True
|
self._is_remote_reader = True
|
||||||
|
|
||||||
self.local_socket = None
|
self.local_socket = None
|
||||||
self.local_sync_socket = None
|
|
||||||
|
|
||||||
self.remote_socket = context.socket(SUB)
|
self.remote_socket = context.socket(SUB)
|
||||||
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
||||||
self.remote_socket.connect(
|
self.remote_socket.connect(
|
||||||
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
|
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
|
||||||
|
|
||||||
self.remote_sync_socket = context.socket(REQ)
|
|
||||||
self.remote_sync_socket.connect(
|
|
||||||
f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def wait_until_ready(self):
|
def wait_until_ready(self):
|
||||||
@ -300,29 +284,27 @@ class MessageQueue:
|
|||||||
|
|
||||||
# local readers
|
# local readers
|
||||||
for i in range(self.n_local_reader):
|
for i in range(self.n_local_reader):
|
||||||
recv = self.local_sync_socket.recv()
|
# wait for subscription messages from all local readers
|
||||||
assert recv == b"READY"
|
self.local_socket.recv()
|
||||||
self.local_sync_socket.send(b"READY")
|
|
||||||
if self.n_local_reader > 0:
|
if self.n_local_reader > 0:
|
||||||
|
# send a message to all local readers
|
||||||
|
# to make sure the publish channel is working
|
||||||
self.local_socket.send(b"READY")
|
self.local_socket.send(b"READY")
|
||||||
|
|
||||||
# remote readers
|
# remote readers
|
||||||
for i in range(self.n_remote_reader):
|
for i in range(self.n_remote_reader):
|
||||||
recv = self.remote_sync_socket.recv()
|
# wait for subscription messages from all remote readers
|
||||||
assert recv == b"READY"
|
self.remote_socket.recv()
|
||||||
self.remote_sync_socket.send(b"READY")
|
|
||||||
if self.n_remote_reader > 0:
|
if self.n_remote_reader > 0:
|
||||||
|
# send a message to all remote readers
|
||||||
|
# to make sure the publish channel is working
|
||||||
self.remote_socket.send(b"READY")
|
self.remote_socket.send(b"READY")
|
||||||
elif self._is_local_reader:
|
elif self._is_local_reader:
|
||||||
self.local_sync_socket.send(b"READY")
|
# wait for the writer to send a message
|
||||||
recv = self.local_sync_socket.recv()
|
|
||||||
assert recv == b"READY"
|
|
||||||
recv = self.local_socket.recv()
|
recv = self.local_socket.recv()
|
||||||
assert recv == b"READY"
|
assert recv == b"READY"
|
||||||
elif self._is_remote_reader:
|
elif self._is_remote_reader:
|
||||||
self.remote_sync_socket.send(b"READY")
|
# wait for the writer to send a message
|
||||||
recv = self.remote_sync_socket.recv()
|
|
||||||
assert recv == b"READY"
|
|
||||||
recv = self.remote_socket.recv()
|
recv = self.remote_socket.recv()
|
||||||
assert recv == b"READY"
|
assert recv == b"READY"
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user