[Frontend] Minor optimizations to zmq decoupled front-end (#7957)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
This commit is contained in:
Nick Hill 2024-08-28 17:22:43 -07:00 committed by GitHub
parent af59df0a10
commit 4289cad37f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 64 additions and 65 deletions

View File

@ -1,11 +1,13 @@
import asyncio
import pickle
from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Mapping, Optional
from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
from uuid import uuid4
import cloudpickle
import zmq
import zmq.asyncio
from zmq.asyncio import Socket
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
@ -115,18 +117,21 @@ class AsyncEngineRPCClient:
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
# IPC connection to RPC Server (uses unix sockets).
self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.to_rpc_server.bind(rpc_path)
# In process proxy to RPC Server (uses memory-based messaging).
self.from_api_server = self.context.socket(zmq.constants.ROUTER)
self.from_api_server: Socket = self.context.socket(
zmq.constants.ROUTER)
self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.from_api_server.bind(INPROC_PROXY_PATH)
# Asyncio background task for the proxy.
self.proxy_task = asyncio.create_task(
self.proxy_in_task = asyncio.create_task(
self.run_proxy(self.from_api_server, self.to_rpc_server))
self.proxy_out_task = asyncio.create_task(
self.run_proxy(self.to_rpc_server, self.from_api_server))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
@ -136,20 +141,11 @@ class AsyncEngineRPCClient:
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self.limit_concurrency = socket_limit // 2 - 2
async def run_proxy(self, socket_from, socket_to):
async def run_proxy(self, socket_from: Socket, socket_to: Socket):
"""Background task that runs a proxy"""
poller = zmq.asyncio.Poller()
poller.register(socket_from, zmq.constants.POLLIN)
poller.register(socket_to, zmq.constants.POLLIN)
while True:
events_lst = await poller.poll()
events = dict(events_lst)
if socket_from in events:
identity, msg = await socket_from.recv_multipart()
await socket_to.send_multipart([identity, msg])
if socket_to in events:
identity, msg = await socket_to.recv_multipart()
await socket_from.send_multipart([identity, msg])
frames = await socket_from.recv_multipart(copy=False)
await socket_to.send_multipart(frames, copy=False)
async def setup(self):
"""Setup the client before it starts sending server requests."""
@ -180,7 +176,7 @@ class AsyncEngineRPCClient:
self.context.destroy()
@contextmanager
def to_proxy_socket(self):
def to_proxy_socket(self) -> Iterator[Socket]:
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
@ -208,7 +204,8 @@ class AsyncEngineRPCClient:
with self.to_proxy_socket() as socket:
# Ping RPCServer with a request.
await socket.send_multipart([cloudpickle.dumps(request)])
await socket.send_multipart((cloudpickle.dumps(request), ),
copy=False)
# Make sure the server responds
if await socket.poll(timeout=self._data_timeout) == 0:
@ -216,7 +213,8 @@ class AsyncEngineRPCClient:
f"{self._data_timeout} ms")
# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())
frame = await socket.recv(copy=False)
data = pickle.loads(frame.buffer)
if isinstance(data, Exception):
# Re-raise exceptions returned by the server
@ -234,23 +232,22 @@ class AsyncEngineRPCClient:
return data
async def _send_one_way_rpc_request(
self,
request: RPC_REQUEST_TYPE,
error_message: str,
socket: Optional[zmq.asyncio.Socket] = None):
async def _send_one_way_rpc_request(self,
request: RPC_REQUEST_TYPE,
error_message: str,
socket: Optional[Socket] = None):
"""Send one-way RPC request to trigger an action."""
async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE):
async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
await socket.send_multipart([cloudpickle.dumps(request)])
await socket.send_multipart((cloudpickle.dumps(request), ))
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")
return cloudpickle.loads(await socket.recv())
frame = await socket.recv(copy=False)
return pickle.loads(frame.buffer)
# Make a new socket connection.
if socket is None:
@ -386,21 +383,19 @@ class AsyncEngineRPCClient:
try:
with self.to_proxy_socket() as socket:
# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([
cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
])
await socket.send_multipart((cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)), ))
# Stream back the results from the RPC Server.
while not finished:
message = await socket.recv()
request_output = cloudpickle.loads(message)
message = await socket.recv(copy=False)
request_output = pickle.loads(message.buffer)
if isinstance(request_output, Exception):
# On exception, check if the server is still healthy
@ -424,9 +419,7 @@ class AsyncEngineRPCClient:
if not finished and not self._errored:
await self.abort(request_id)
async def check_health(self,
socket: Optional[zmq.asyncio.Socket] = None
) -> None:
async def check_health(self, socket: Optional[Socket] = None) -> None:
"""Raise if unhealthy"""
await self._send_one_way_rpc_request(
@ -451,4 +444,4 @@ class AsyncEngineRPCClient:
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.STOP_PROFILE,
error_message="RPCRequest STOP_PROFILE failed.")
error_message="RPCRequest STOP_PROFILE failed.")

View File

@ -1,4 +1,5 @@
import asyncio
import pickle
import signal
from typing import Any, Coroutine, Union
@ -7,6 +8,8 @@ import uvloop
import zmq
import zmq.asyncio
from typing_extensions import Never
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@ -35,7 +38,7 @@ class AsyncEngineRPCServer:
self.context = zmq.asyncio.Context()
# Init socket.
self.socket = self.context.socket(zmq.constants.DEALER)
self.socket: Socket = self.context.socket(zmq.constants.DEALER)
self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
self.socket.connect(rpc_path)
@ -63,30 +66,31 @@ class AsyncEngineRPCServer:
else:
raise ValueError("Unknown Config Request: %s", request)
await self.socket.send_multipart(
[identity, cloudpickle.dumps(config)])
await self.socket.send_multipart((identity, pickle.dumps(config)),
copy=False)
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(tracing_flag)])
(identity, pickle.dumps(tracing_flag)))
async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
@ -96,7 +100,7 @@ class AsyncEngineRPCServer:
result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
except Exception as e:
result = e
await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
await self.socket.send_multipart((identity, pickle.dumps(result)))
async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
@ -110,45 +114,47 @@ class AsyncEngineRPCServer:
async for request_output in results_generator:
await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)])
(identity, pickle.dumps(request_output)), copy=False)
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def start_profile(self, identity):
logger.info("Starting profiler...")
await self.engine.start_profile()
logger.info("Profiler started.")
await self.socket.send_multipart([
await self.socket.send_multipart((
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
pickle.dumps(VLLM_RPC_SUCCESS_STR),
))
async def stop_profile(self, identity):
logger.info("Stopping profiler...")
await self.engine.stop_profile()
logger.info("Profiler stopped.")
await self.socket.send_multipart([
await self.socket.send_multipart((
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
pickle.dumps(VLLM_RPC_SUCCESS_STR),
))
def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
message: Frame) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""
request = cloudpickle.loads(message)
request = cloudpickle.loads(message.buffer)
if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request)
@ -189,7 +195,7 @@ class AsyncEngineRPCServer:
running_tasks = set()
while True:
# Wait for a request.
identity, message = await self.socket.recv_multipart()
identity, message = await self.socket.recv_multipart(copy=False)
# Process the request async.
task = asyncio.create_task(