[Frontend] Minor optimizations to zmq decoupled front-end (#7957)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
commit 4289cad37f
parent af59df0a10
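For orientation, here is a minimal, self-contained sketch (not the vLLM code; the endpoint strings are made up) of the topology this diff touches: a ROUTER socket facing the API server over inproc, a DEALER socket facing the RPC server over IPC, and two background tasks that each forward multipart frames in one direction with copy=False — the pattern the client adopts in place of a single Poller-based proxy loop.

# Illustrative sketch only, assuming pyzmq's asyncio API.
import asyncio

import zmq
import zmq.asyncio
from zmq.asyncio import Socket

INPROC_PROXY_PATH = "inproc://proxy"        # hypothetical endpoint names
RPC_PATH = "ipc:///tmp/example_rpc_socket"  # hypothetical endpoint names


async def run_proxy(socket_from: Socket, socket_to: Socket):
    """Forward multipart messages in one direction without copying."""
    while True:
        frames = await socket_from.recv_multipart(copy=False)
        await socket_to.send_multipart(frames, copy=False)


async def main():
    ctx = zmq.asyncio.Context()

    # DEALER socket towards the RPC server (IPC transport).
    to_rpc_server: Socket = ctx.socket(zmq.constants.DEALER)
    to_rpc_server.bind(RPC_PATH)

    # ROUTER socket that per-request inproc sockets connect to.
    from_api_server: Socket = ctx.socket(zmq.constants.ROUTER)
    from_api_server.bind(INPROC_PROXY_PATH)

    # One forwarding task per direction instead of one Poller-based loop.
    proxy_in = asyncio.create_task(run_proxy(from_api_server, to_rpc_server))
    proxy_out = asyncio.create_task(run_proxy(to_rpc_server, from_api_server))
    await asyncio.gather(proxy_in, proxy_out)


if __name__ == "__main__":
    asyncio.run(main())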
@@ -1,11 +1,13 @@
 import asyncio
+import pickle
 from contextlib import contextmanager, suppress
-from typing import Any, AsyncGenerator, Mapping, Optional
+from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
 from uuid import uuid4

 import cloudpickle
 import zmq
 import zmq.asyncio
+from zmq.asyncio import Socket

 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
@@ -115,18 +117,21 @@ class AsyncEngineRPCClient:
         self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)

         # IPC connection to RPC Server (uses unix sockets).
-        self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
+        self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
         self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
         self.to_rpc_server.bind(rpc_path)

         # In process proxy to RPC Server (uses memory-based messaging).
-        self.from_api_server = self.context.socket(zmq.constants.ROUTER)
+        self.from_api_server: Socket = self.context.socket(
+            zmq.constants.ROUTER)
         self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
         self.from_api_server.bind(INPROC_PROXY_PATH)

         # Asyncio background task for the proxy.
-        self.proxy_task = asyncio.create_task(
+        self.proxy_in_task = asyncio.create_task(
             self.run_proxy(self.from_api_server, self.to_rpc_server))
+        self.proxy_out_task = asyncio.create_task(
+            self.run_proxy(self.to_rpc_server, self.from_api_server))

         # Since we open 1 inproc socket per request, we have a hard cap on
         # the number of requests that can run in vLLM w. frontend
@@ -136,20 +141,11 @@ class AsyncEngineRPCClient:
         # 1 for generate(), 1 for abort(), do_log_stats(), check_health()
         self.limit_concurrency = socket_limit // 2 - 2

-    async def run_proxy(self, socket_from, socket_to):
+    async def run_proxy(self, socket_from: Socket, socket_to: Socket):
         """Background task that runs a proxy"""
-        poller = zmq.asyncio.Poller()
-        poller.register(socket_from, zmq.constants.POLLIN)
-        poller.register(socket_to, zmq.constants.POLLIN)
         while True:
-            events_lst = await poller.poll()
-            events = dict(events_lst)
-            if socket_from in events:
-                identity, msg = await socket_from.recv_multipart()
-                await socket_to.send_multipart([identity, msg])
-            if socket_to in events:
-                identity, msg = await socket_to.recv_multipart()
-                await socket_from.send_multipart([identity, msg])
+            frames = await socket_from.recv_multipart(copy=False)
+            await socket_to.send_multipart(frames, copy=False)

     async def setup(self):
         """Setup the client before it starts sending server requests."""
@@ -180,7 +176,7 @@ class AsyncEngineRPCClient:
         self.context.destroy()

     @contextmanager
-    def to_proxy_socket(self):
+    def to_proxy_socket(self) -> Iterator[Socket]:
         # Connect to the RPCServer via the proxy.

         # Raise a sensible error if the client was already closed.
@@ -208,7 +204,8 @@ class AsyncEngineRPCClient:

         with self.to_proxy_socket() as socket:
             # Ping RPCServer with a request.
-            await socket.send_multipart([cloudpickle.dumps(request)])
+            await socket.send_multipart((cloudpickle.dumps(request), ),
+                                        copy=False)

             # Make sure the server responds
             if await socket.poll(timeout=self._data_timeout) == 0:
@@ -216,7 +213,8 @@ class AsyncEngineRPCClient:
                                    f"{self._data_timeout} ms")

             # Await the data from the Server.
-            data = cloudpickle.loads(await socket.recv())
+            frame = await socket.recv(copy=False)
+            data = pickle.loads(frame.buffer)

         if isinstance(data, Exception):
             # Re-raise exceptions returned by the server
@@ -234,23 +232,22 @@ class AsyncEngineRPCClient:

         return data

-    async def _send_one_way_rpc_request(
-            self,
-            request: RPC_REQUEST_TYPE,
-            error_message: str,
-            socket: Optional[zmq.asyncio.Socket] = None):
+    async def _send_one_way_rpc_request(self,
+                                        request: RPC_REQUEST_TYPE,
+                                        error_message: str,
+                                        socket: Optional[Socket] = None):
         """Send one-way RPC request to trigger an action."""

-        async def do_rpc_call(socket: zmq.asyncio.Socket,
-                              request: RPC_REQUEST_TYPE):
+        async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):

-            await socket.send_multipart([cloudpickle.dumps(request)])
+            await socket.send_multipart((cloudpickle.dumps(request), ))

             if await socket.poll(timeout=self._data_timeout) == 0:
                 raise TimeoutError("Server didn't reply within "
                                    f"{self._data_timeout} ms")

-            return cloudpickle.loads(await socket.recv())
+            frame = await socket.recv(copy=False)
+            return pickle.loads(frame.buffer)

         # Make a new socket connection.
         if socket is None:
@@ -386,21 +383,19 @@ class AsyncEngineRPCClient:
         try:
             with self.to_proxy_socket() as socket:
                 # Send RPCGenerateRequest to the RPCServer.
-                await socket.send_multipart([
-                    cloudpickle.dumps(
-                        RPCGenerateRequest(
-                            inputs=inputs,
-                            sampling_params=sampling_params,
-                            request_id=request_id,
-                            lora_request=lora_request,
-                            trace_headers=trace_headers,
-                            prompt_adapter_request=prompt_adapter_request))
-                ])
+                await socket.send_multipart((cloudpickle.dumps(
+                    RPCGenerateRequest(
+                        inputs=inputs,
+                        sampling_params=sampling_params,
+                        request_id=request_id,
+                        lora_request=lora_request,
+                        trace_headers=trace_headers,
+                        prompt_adapter_request=prompt_adapter_request)), ))

                 # Stream back the results from the RPC Server.
                 while not finished:
-                    message = await socket.recv()
-                    request_output = cloudpickle.loads(message)
+                    message = await socket.recv(copy=False)
+                    request_output = pickle.loads(message.buffer)

                     if isinstance(request_output, Exception):
                         # On exception, check if the server is still healthy
@@ -424,9 +419,7 @@ class AsyncEngineRPCClient:
             if not finished and not self._errored:
                 await self.abort(request_id)

-    async def check_health(self,
-                           socket: Optional[zmq.asyncio.Socket] = None
-                           ) -> None:
+    async def check_health(self, socket: Optional[Socket] = None) -> None:
         """Raise if unhealthy"""

         await self._send_one_way_rpc_request(
@@ -451,4 +444,4 @@ class AsyncEngineRPCClient:

         await self._send_one_way_rpc_request(
             request=RPCUtilityRequest.STOP_PROFILE,
-            error_message="RPCRequest STOP_PROFILE failed.")
+            error_message="RPCRequest STOP_PROFILE failed.")
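As a rough illustration of the client-side request/response pattern the hunks above converge on (the helper name and timeout below are made up; this is a sketch, not the vLLM implementation): the request is serialized with cloudpickle and sent as a single-frame tuple with copy=False, and the reply is received as a zmq Frame and deserialized with plain pickle straight from its buffer.

import pickle

import cloudpickle
from zmq.asyncio import Socket


async def rpc_call(socket: Socket, request, timeout_ms: int = 5000):
    """Send one pickled request frame and return the unpickled reply."""
    # A tuple with a single frame; copy=False avoids an extra copy in pyzmq.
    await socket.send_multipart((cloudpickle.dumps(request), ), copy=False)

    if await socket.poll(timeout=timeout_ms) == 0:
        raise TimeoutError(f"Server didn't reply within {timeout_ms} ms")

    # copy=False returns a zmq.Frame; pickle reads from its memoryview buffer.
    frame = await socket.recv(copy=False)
    return pickle.loads(frame.buffer)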
@@ -1,4 +1,5 @@
 import asyncio
+import pickle
 import signal
 from typing import Any, Coroutine, Union

@@ -7,6 +8,8 @@ import uvloop
 import zmq
 import zmq.asyncio
 from typing_extensions import Never
+from zmq import Frame  # type: ignore[attr-defined]
+from zmq.asyncio import Socket

 from vllm import AsyncEngineArgs, AsyncLLMEngine
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@@ -35,7 +38,7 @@ class AsyncEngineRPCServer:
         self.context = zmq.asyncio.Context()

         # Init socket.
-        self.socket = self.context.socket(zmq.constants.DEALER)
+        self.socket: Socket = self.context.socket(zmq.constants.DEALER)
         self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
         self.socket.connect(rpc_path)

@@ -63,30 +66,31 @@ class AsyncEngineRPCServer:
             else:
                 raise ValueError("Unknown Config Request: %s", request)

-            await self.socket.send_multipart(
-                [identity, cloudpickle.dumps(config)])
+            await self.socket.send_multipart((identity, pickle.dumps(config)),
+                                             copy=False)

         except Exception as e:
-            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+            await self.socket.send_multipart((identity, pickle.dumps(e)),
+                                             copy=False)

     async def is_tracing_enabled(self, identity):
         """Send the is_tracing_enabled flag"""
         tracing_flag = await self.engine.is_tracing_enabled()

         await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(tracing_flag)])
+            (identity, pickle.dumps(tracing_flag)))

     async def do_log_stats(self, identity):
         """Log stats and confirm success."""
         await self.engine.do_log_stats()

         await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
+            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

     async def is_server_ready(self, identity):
         """Notify the client that we are ready."""
         await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
+            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

     async def abort(self, identity, request: RPCAbortRequest):
         """Abort request and notify the client of success."""
@@ -96,7 +100,7 @@ class AsyncEngineRPCServer:
             result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
         except Exception as e:
             result = e
-        await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
+        await self.socket.send_multipart((identity, pickle.dumps(result)))

     async def generate(self, identity, generate_request: RPCGenerateRequest):
         try:
@@ -110,45 +114,47 @@ class AsyncEngineRPCServer:

             async for request_output in results_generator:
                 await self.socket.send_multipart(
-                    [identity, cloudpickle.dumps(request_output)])
+                    (identity, pickle.dumps(request_output)), copy=False)

         except Exception as e:
-            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+            await self.socket.send_multipart((identity, pickle.dumps(e)),
+                                             copy=False)

     async def check_health(self, identity):
         try:
             await self.engine.check_health()
             await self.socket.send_multipart(
-                [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
+                (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

         except Exception as e:
-            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+            await self.socket.send_multipart((identity, pickle.dumps(e)),
+                                             copy=False)

     async def start_profile(self, identity):
         logger.info("Starting profiler...")
         await self.engine.start_profile()
         logger.info("Profiler started.")

-        await self.socket.send_multipart([
+        await self.socket.send_multipart((
             identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
+            pickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ))

     async def stop_profile(self, identity):
         logger.info("Stopping profiler...")
         await self.engine.stop_profile()
         logger.info("Profiler stopped.")

-        await self.socket.send_multipart([
+        await self.socket.send_multipart((
             identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
+            pickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ))

     def _make_handler_coro(self, identity,
-                           message) -> Coroutine[Any, Any, Never]:
+                           message: Frame) -> Coroutine[Any, Any, Never]:
         """Route the zmq message to the handler coroutine."""

-        request = cloudpickle.loads(message)
+        request = cloudpickle.loads(message.buffer)

         if isinstance(request, RPCGenerateRequest):
             return self.generate(identity, request)
@@ -189,7 +195,7 @@ class AsyncEngineRPCServer:
         running_tasks = set()
         while True:
             # Wait for a request.
-            identity, message = await self.socket.recv_multipart()
+            identity, message = await self.socket.recv_multipart(copy=False)

             # Process the request async.
             task = asyncio.create_task(
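A corresponding server-side sketch (again illustrative, with a hypothetical handle() standing in for the real request routing): requests arrive as (identity, message) frames via recv_multipart(copy=False), the request is decoded from the frame buffer with cloudpickle as in the diff above, and replies are plain objects, so pickle.dumps plus copy=False suffices on the way back.

import pickle

import cloudpickle
from zmq.asyncio import Socket


def handle(request):
    """Hypothetical request handler; echoes the request back."""
    return request


async def serve(socket: Socket):
    """Toy DEALER-side loop mirroring the zero-copy recv/send pattern."""
    while True:
        # copy=False yields zmq.Frame objects instead of copied bytes.
        identity, message = await socket.recv_multipart(copy=False)

        # The server keeps cloudpickle for decoding incoming requests.
        request = cloudpickle.loads(message.buffer)
        reply = handle(request)

        # Replies are picklable objects; the identity Frame is re-sent as-is.
        await socket.send_multipart((identity, pickle.dumps(reply)),
                                    copy=False)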