[BugFix] Fix frontend multiprocessing hang (#7217)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
This commit is contained in:
parent
0e12cd67a8
commit
fde47d3bc2
35
tests/entrypoints/openai/test_mp_crash.py
Normal file
35
tests/entrypoints/openai/test_mp_crash.py
Normal file
@ -0,0 +1,35 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def crashing_from_engine_args(
|
||||
cls,
|
||||
engine_args: Any = None,
|
||||
start_engine_loop: Any = None,
|
||||
usage_context: Any = None,
|
||||
stat_loggers: Any = None,
|
||||
) -> "AsyncLLMEngine":
|
||||
raise Exception("foo")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mp_crash_detection(monkeypatch):
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo, monkeypatch.context() as m:
|
||||
m.setattr(AsyncLLMEngine, "from_engine_args",
|
||||
crashing_from_engine_args)
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM's remote OpenAI server.")
|
||||
parser = make_arg_parser(parser)
|
||||
args = parser.parse_args([])
|
||||
|
||||
async with build_async_engine_client(args):
|
||||
pass
|
||||
assert "The server process died before responding to the readiness probe"\
|
||||
in str(excinfo.value)
|
||||
@ -120,9 +120,18 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
|
||||
|
||||
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
|
||||
async_engine_client = AsyncEngineRPCClient(rpc_path)
|
||||
await async_engine_client.setup()
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
await async_engine_client.setup()
|
||||
break
|
||||
except TimeoutError as e:
|
||||
if not rpc_server_process.is_alive():
|
||||
raise RuntimeError(
|
||||
"The server process died before "
|
||||
"responding to the readiness probe") from e
|
||||
|
||||
yield async_engine_client
|
||||
finally:
|
||||
# Ensure rpc server process was terminated
|
||||
|
||||
@ -18,6 +18,9 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
|
||||
# Time to wait before checking it the server process is alive.
|
||||
SERVER_START_TIMEOUT_MS = 1000
|
||||
|
||||
|
||||
class AsyncEngineRPCClient:
|
||||
|
||||
@ -61,7 +64,16 @@ class AsyncEngineRPCClient:
|
||||
socket.connect(self.rpc_path)
|
||||
yield socket
|
||||
finally:
|
||||
socket.close()
|
||||
# linger == 0 means discard unsent messages
|
||||
# when the socket is closed. This is necessary
|
||||
# because otherwise self.context.destroy() will
|
||||
# wait for 30 seconds until unsent messages are
|
||||
# received, which is impossible if the server
|
||||
# crashed. In the absence of a server crash we
|
||||
# always expect a response before closing the
|
||||
# socket anyway.
|
||||
# Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
|
||||
socket.close(linger=0)
|
||||
|
||||
async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
|
||||
expected_type: Any,
|
||||
@ -85,14 +97,19 @@ class AsyncEngineRPCClient:
|
||||
|
||||
return data
|
||||
|
||||
async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
|
||||
error_message: str):
|
||||
async def _send_one_way_rpc_request(self,
|
||||
request: RPC_REQUEST_TYPE,
|
||||
error_message: str,
|
||||
timeout: Optional[int] = None):
|
||||
"""Send one-way RPC request to trigger an action."""
|
||||
with self.socket() as socket:
|
||||
# Ping RPC Server with request.
|
||||
await socket.send(cloudpickle.dumps(request))
|
||||
|
||||
# Await acknowledgement from RPCServer.
|
||||
if timeout is not None and await socket.poll(timeout=timeout) == 0:
|
||||
raise TimeoutError(f"server didn't reply within {timeout} ms")
|
||||
|
||||
response = cloudpickle.loads(await socket.recv())
|
||||
|
||||
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
|
||||
@ -117,7 +134,8 @@ class AsyncEngineRPCClient:
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.IS_SERVER_READY,
|
||||
error_message="Unable to start RPC Server.")
|
||||
error_message="Unable to start RPC Server.",
|
||||
timeout=SERVER_START_TIMEOUT_MS)
|
||||
|
||||
async def _get_model_config_rpc(self) -> ModelConfig:
|
||||
"""Get the ModelConfig object from the RPC Server"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user