[Frontend] Minor optimizations to zmq decoupled front-end (#7957)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
This commit is contained in:
Nick Hill 2024-08-28 17:22:43 -07:00 committed by GitHub
parent af59df0a10
commit 4289cad37f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 64 additions and 65 deletions

View File

@ -1,11 +1,13 @@
import asyncio import asyncio
import pickle
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Mapping, Optional from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
from uuid import uuid4 from uuid import uuid4
import cloudpickle import cloudpickle
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from zmq.asyncio import Socket
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
@ -115,18 +117,21 @@ class AsyncEngineRPCClient:
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
# IPC connection to RPC Server (uses unix sockets). # IPC connection to RPC Server (uses unix sockets).
self.to_rpc_server = self.context.socket(zmq.constants.DEALER) self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM) self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.to_rpc_server.bind(rpc_path) self.to_rpc_server.bind(rpc_path)
# In process proxy to RPC Server (uses memory-based messaging). # In process proxy to RPC Server (uses memory-based messaging).
self.from_api_server = self.context.socket(zmq.constants.ROUTER) self.from_api_server: Socket = self.context.socket(
zmq.constants.ROUTER)
self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM) self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.from_api_server.bind(INPROC_PROXY_PATH) self.from_api_server.bind(INPROC_PROXY_PATH)
# Asyncio background task for the proxy. # Asyncio background task for the proxy.
self.proxy_task = asyncio.create_task( self.proxy_in_task = asyncio.create_task(
self.run_proxy(self.from_api_server, self.to_rpc_server)) self.run_proxy(self.from_api_server, self.to_rpc_server))
self.proxy_out_task = asyncio.create_task(
self.run_proxy(self.to_rpc_server, self.from_api_server))
# Since we open 1 inproc socket per request, we have a hard cap on # Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend # the number of requests that can run in vLLM w. frontend
@ -136,20 +141,11 @@ class AsyncEngineRPCClient:
# 1 for generate(), 1 for abort(), do_log_stats(), check_health() # 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self.limit_concurrency = socket_limit // 2 - 2 self.limit_concurrency = socket_limit // 2 - 2
async def run_proxy(self, socket_from, socket_to): async def run_proxy(self, socket_from: Socket, socket_to: Socket):
"""Background task that runs a proxy""" """Background task that runs a proxy"""
poller = zmq.asyncio.Poller()
poller.register(socket_from, zmq.constants.POLLIN)
poller.register(socket_to, zmq.constants.POLLIN)
while True: while True:
events_lst = await poller.poll() frames = await socket_from.recv_multipart(copy=False)
events = dict(events_lst) await socket_to.send_multipart(frames, copy=False)
if socket_from in events:
identity, msg = await socket_from.recv_multipart()
await socket_to.send_multipart([identity, msg])
if socket_to in events:
identity, msg = await socket_to.recv_multipart()
await socket_from.send_multipart([identity, msg])
async def setup(self): async def setup(self):
"""Setup the client before it starts sending server requests.""" """Setup the client before it starts sending server requests."""
@ -180,7 +176,7 @@ class AsyncEngineRPCClient:
self.context.destroy() self.context.destroy()
@contextmanager @contextmanager
def to_proxy_socket(self): def to_proxy_socket(self) -> Iterator[Socket]:
# Connect to the RPCServer via the proxy. # Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed. # Raise a sensible error if the client was already closed.
@ -208,7 +204,8 @@ class AsyncEngineRPCClient:
with self.to_proxy_socket() as socket: with self.to_proxy_socket() as socket:
# Ping RPCServer with a request. # Ping RPCServer with a request.
await socket.send_multipart([cloudpickle.dumps(request)]) await socket.send_multipart((cloudpickle.dumps(request), ),
copy=False)
# Make sure the server responds # Make sure the server responds
if await socket.poll(timeout=self._data_timeout) == 0: if await socket.poll(timeout=self._data_timeout) == 0:
@ -216,7 +213,8 @@ class AsyncEngineRPCClient:
f"{self._data_timeout} ms") f"{self._data_timeout} ms")
# Await the data from the Server. # Await the data from the Server.
data = cloudpickle.loads(await socket.recv()) frame = await socket.recv(copy=False)
data = pickle.loads(frame.buffer)
if isinstance(data, Exception): if isinstance(data, Exception):
# Re-raise exceptions returned by the server # Re-raise exceptions returned by the server
@ -234,23 +232,22 @@ class AsyncEngineRPCClient:
return data return data
async def _send_one_way_rpc_request( async def _send_one_way_rpc_request(self,
self, request: RPC_REQUEST_TYPE,
request: RPC_REQUEST_TYPE, error_message: str,
error_message: str, socket: Optional[Socket] = None):
socket: Optional[zmq.asyncio.Socket] = None):
"""Send one-way RPC request to trigger an action.""" """Send one-way RPC request to trigger an action."""
async def do_rpc_call(socket: zmq.asyncio.Socket, async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
request: RPC_REQUEST_TYPE):
await socket.send_multipart([cloudpickle.dumps(request)]) await socket.send_multipart((cloudpickle.dumps(request), ))
if await socket.poll(timeout=self._data_timeout) == 0: if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within " raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms") f"{self._data_timeout} ms")
return cloudpickle.loads(await socket.recv()) frame = await socket.recv(copy=False)
return pickle.loads(frame.buffer)
# Make a new socket connection. # Make a new socket connection.
if socket is None: if socket is None:
@ -386,21 +383,19 @@ class AsyncEngineRPCClient:
try: try:
with self.to_proxy_socket() as socket: with self.to_proxy_socket() as socket:
# Send RPCGenerateRequest to the RPCServer. # Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([ await socket.send_multipart((cloudpickle.dumps(
cloudpickle.dumps( RPCGenerateRequest(
RPCGenerateRequest( inputs=inputs,
inputs=inputs, sampling_params=sampling_params,
sampling_params=sampling_params, request_id=request_id,
request_id=request_id, lora_request=lora_request,
lora_request=lora_request, trace_headers=trace_headers,
trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request)), ))
prompt_adapter_request=prompt_adapter_request))
])
# Stream back the results from the RPC Server. # Stream back the results from the RPC Server.
while not finished: while not finished:
message = await socket.recv() message = await socket.recv(copy=False)
request_output = cloudpickle.loads(message) request_output = pickle.loads(message.buffer)
if isinstance(request_output, Exception): if isinstance(request_output, Exception):
# On exception, check if the server is still healthy # On exception, check if the server is still healthy
@ -424,9 +419,7 @@ class AsyncEngineRPCClient:
if not finished and not self._errored: if not finished and not self._errored:
await self.abort(request_id) await self.abort(request_id)
async def check_health(self, async def check_health(self, socket: Optional[Socket] = None) -> None:
socket: Optional[zmq.asyncio.Socket] = None
) -> None:
"""Raise if unhealthy""" """Raise if unhealthy"""
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
@ -451,4 +444,4 @@ class AsyncEngineRPCClient:
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCUtilityRequest.STOP_PROFILE, request=RPCUtilityRequest.STOP_PROFILE,
error_message="RPCRequest STOP_PROFILE failed.") error_message="RPCRequest STOP_PROFILE failed.")

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import pickle
import signal import signal
from typing import Any, Coroutine, Union from typing import Any, Coroutine, Union
@ -7,6 +8,8 @@ import uvloop
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from typing_extensions import Never from typing_extensions import Never
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm import AsyncEngineArgs, AsyncLLMEngine from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@ -35,7 +38,7 @@ class AsyncEngineRPCServer:
self.context = zmq.asyncio.Context() self.context = zmq.asyncio.Context()
# Init socket. # Init socket.
self.socket = self.context.socket(zmq.constants.DEALER) self.socket: Socket = self.context.socket(zmq.constants.DEALER)
self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
self.socket.connect(rpc_path) self.socket.connect(rpc_path)
@ -63,30 +66,31 @@ class AsyncEngineRPCServer:
else: else:
raise ValueError("Unknown Config Request: %s", request) raise ValueError("Unknown Config Request: %s", request)
await self.socket.send_multipart( await self.socket.send_multipart((identity, pickle.dumps(config)),
[identity, cloudpickle.dumps(config)]) copy=False)
except Exception as e: except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def is_tracing_enabled(self, identity): async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag""" """Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled() tracing_flag = await self.engine.is_tracing_enabled()
await self.socket.send_multipart( await self.socket.send_multipart(
[identity, cloudpickle.dumps(tracing_flag)]) (identity, pickle.dumps(tracing_flag)))
async def do_log_stats(self, identity): async def do_log_stats(self, identity):
"""Log stats and confirm success.""" """Log stats and confirm success."""
await self.engine.do_log_stats() await self.engine.do_log_stats()
await self.socket.send_multipart( await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
async def is_server_ready(self, identity): async def is_server_ready(self, identity):
"""Notify the client that we are ready.""" """Notify the client that we are ready."""
await self.socket.send_multipart( await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
async def abort(self, identity, request: RPCAbortRequest): async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success.""" """Abort request and notify the client of success."""
@ -96,7 +100,7 @@ class AsyncEngineRPCServer:
result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
except Exception as e: except Exception as e:
result = e result = e
await self.socket.send_multipart([identity, cloudpickle.dumps(result)]) await self.socket.send_multipart((identity, pickle.dumps(result)))
async def generate(self, identity, generate_request: RPCGenerateRequest): async def generate(self, identity, generate_request: RPCGenerateRequest):
try: try:
@ -110,45 +114,47 @@ class AsyncEngineRPCServer:
async for request_output in results_generator: async for request_output in results_generator:
await self.socket.send_multipart( await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)]) (identity, pickle.dumps(request_output)), copy=False)
except Exception as e: except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def check_health(self, identity): async def check_health(self, identity):
try: try:
await self.engine.check_health() await self.engine.check_health()
await self.socket.send_multipart( await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
except Exception as e: except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def start_profile(self, identity): async def start_profile(self, identity):
logger.info("Starting profiler...") logger.info("Starting profiler...")
await self.engine.start_profile() await self.engine.start_profile()
logger.info("Profiler started.") logger.info("Profiler started.")
await self.socket.send_multipart([ await self.socket.send_multipart((
identity, identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), pickle.dumps(VLLM_RPC_SUCCESS_STR),
]) ))
async def stop_profile(self, identity): async def stop_profile(self, identity):
logger.info("Stopping profiler...") logger.info("Stopping profiler...")
await self.engine.stop_profile() await self.engine.stop_profile()
logger.info("Profiler stopped.") logger.info("Profiler stopped.")
await self.socket.send_multipart([ await self.socket.send_multipart((
identity, identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), pickle.dumps(VLLM_RPC_SUCCESS_STR),
]) ))
def _make_handler_coro(self, identity, def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]: message: Frame) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine.""" """Route the zmq message to the handler coroutine."""
request = cloudpickle.loads(message) request = cloudpickle.loads(message.buffer)
if isinstance(request, RPCGenerateRequest): if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request) return self.generate(identity, request)
@ -189,7 +195,7 @@ class AsyncEngineRPCServer:
running_tasks = set() running_tasks = set()
while True: while True:
# Wait for a request. # Wait for a request.
identity, message = await self.socket.recv_multipart() identity, message = await self.socket.recv_multipart(copy=False)
# Process the request async. # Process the request async.
task = asyncio.create_task( task = asyncio.create_task(