[Frontend] Kill the server on engine death (#6594)

Signed-off-by: Joe Runde <joe@joerun.de>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>

parent 5fb4a3f678, commit 21b9c49aa3

tests/entrypoints/openai/test_shutdown.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+import json
+import os
+
+import openai
+import pytest
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.mark.asyncio
+async def test_shutdown_on_engine_failure(tmp_path):
+    # Use a bad adapter to crash the engine
+    # (This test will fail when that bug is fixed)
+    adapter_path = tmp_path / "bad_adapter"
+    os.mkdir(adapter_path)
+    with open(adapter_path / "adapter_model_config.json", "w") as f:
+        json.dump({"not": "real"}, f)
+    with open(adapter_path / "adapter_model.safetensors", "wb") as f:
+        f.write(b"this is fake")
+
+    # dtype, max-len etc set so that this can run in CI
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        "--max-num-seqs",
+        "128",
+        "--enable-lora",
+        "--lora-modules",
+        f"bad-adapter={tmp_path / 'bad_adapter'}",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        client = remote_server.get_async_client()
+
+        with pytest.raises(openai.APIConnectionError):
+            # This crashes the engine
+            await client.completions.create(model="bad-adapter",
+                                            prompt="Hello, my name is")
+
+        # Now the server should shut down
+        return_code = remote_server.proc.wait(timeout=1)
+        assert return_code is not None
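A note on the final assertion: subprocess.Popen.wait(timeout=...) returns the process's exit code once the process has terminated and raises subprocess.TimeoutExpired if it is still running, so merely reaching the assert means the server process really exited within a second of the failed request. A stdlib-only sketch of that behavior (the python -c child processes are only stand-ins for the API server):

import subprocess
import sys

# A process that exits promptly: wait() returns its exit code.
proc = subprocess.Popen([sys.executable, "-c", "raise SystemExit(3)"])
print(proc.wait(timeout=5))  # 3

# A process that outlives the timeout: wait() raises TimeoutExpired.
proc = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(30)"])
try:
    proc.wait(timeout=0.5)
except subprocess.TimeoutExpired:
    print("still running")
finally:
    proc.kill()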
@@ -58,7 +58,7 @@ def _log_task_completion(task: asyncio.Task,
         error_callback(exception)
         raise AsyncEngineDeadError(
             "Task finished unexpectedly. This should never happen! "
-            "Please open an issue on Github. See stack trace above for the"
+            "Please open an issue on Github. See stack trace above for the "
             "actual cause.") from e
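The one-character fix above matters because Python joins adjacent string literals with no separator, so the missing trailing space ran two words together in the error message. A minimal illustration:

# Adjacent string literals are concatenated verbatim, with no space inserted:
before = ("See stack trace above for the"
          "actual cause.")
after = ("See stack trace above for the "
         "actual cause.")
print(before)  # 'See stack trace above for theactual cause.'
print(after)   # 'See stack trace above for the actual cause.'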
@@ -132,7 +132,9 @@ class RequestTracker:
             self._request_streams[request_id].put(exc)
             self.abort_request(request_id)
         else:
-            for rid, stream in self._request_streams.items():
+            # NB: list() used here because self.abort_request pops the stream
+            # out of self._request_streams, so we can't iterate on it directly
+            for rid, stream in list(self._request_streams.items()):
                 stream.put(exc)
                 self.abort_request(rid)
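The list() copy is what makes this loop safe: abort_request() removes entries from self._request_streams while the loop is still running, and mutating a dict during direct iteration raises RuntimeError. A small self-contained demonstration (toy names, not vLLM's):

streams = {"req-1": "stream-1", "req-2": "stream-2"}

# Deleting from the dict while iterating over it directly fails:
try:
    for rid in streams:
        del streams[rid]
except RuntimeError as e:
    print(e)  # dictionary changed size during iteration

# Snapshotting the items first makes the removal safe:
streams = {"req-1": "stream-1", "req-2": "stream-2"}
for rid, stream in list(streams.items()):
    del streams[rid]
print(streams)  # {}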
@@ -118,6 +118,7 @@ async def run_server(args: Namespace,

     shutdown_task = await serve_http(
         app,
+        engine=engine,
         host=args.host,
         port=args.port,
         log_level=args.log_level,
@@ -1,16 +1,21 @@
 import asyncio
 import signal
+from http import HTTPStatus
 from typing import Any

 import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Response

+from vllm import envs
+from vllm.engine.async_llm_engine import AsyncEngineDeadError
+from vllm.engine.protocol import AsyncEngineClient
 from vllm.logger import init_logger

 logger = init_logger(__name__)


-async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
+async def serve_http(app: FastAPI, engine: AsyncEngineClient,
+                     **uvicorn_kwargs: Any):
     logger.info("Available routes are:")
     for route in app.routes:
         methods = getattr(route, "methods", None)
@@ -23,6 +28,7 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):

     config = uvicorn.Config(app, **uvicorn_kwargs)
     server = uvicorn.Server(config)
+    _add_shutdown_handlers(app, server, engine)

     loop = asyncio.get_running_loop()
@@ -44,3 +50,37 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
     except asyncio.CancelledError:
         logger.info("Gracefully stopping http server")
         return server.shutdown()
+
+
+def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
+                           engine: AsyncEngineClient) -> None:
+    """Adds handlers for fatal errors that should crash the server"""
+
+    @app.exception_handler(RuntimeError)
+    async def runtime_error_handler(_, __):
+        """On generic runtime error, check to see if the engine has died.
+        It probably has, in which case the server will no longer be able to
+        handle requests. Trigger a graceful shutdown with a SIGTERM."""
+        if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
+                and not engine.is_running):
+            logger.fatal("AsyncLLMEngine has failed, terminating server "
+                         "process")
+            # See discussions here on shutting down a uvicorn server
+            # https://github.com/encode/uvicorn/discussions/1103
+            # In this case we cannot await the server shutdown here because
+            # this handler must first return to close the connection for
+            # this request.
+            server.should_exit = True
+
+        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+    @app.exception_handler(AsyncEngineDeadError)
+    async def engine_dead_handler(_, __):
+        """Kill the server if the async engine is already dead. It will
+        not handle any further requests."""
+        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
+            logger.fatal("AsyncLLMEngine is already dead, terminating server "
+                         "process")
+            server.should_exit = True
+
+        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
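The comment in runtime_error_handler points at the key constraint: the handler cannot await server.shutdown() itself, because it still has to return a response for the in-flight request; instead it sets uvicorn's should_exit flag and lets the serve loop wind down on its own. Below is a stripped-down sketch of that pattern outside vLLM (the app, route, and handler names are illustrative, not part of this change):

import asyncio
from http import HTTPStatus

import uvicorn
from fastapi import FastAPI, Response

app = FastAPI()


@app.get("/boom")
async def boom():
    raise RuntimeError("simulated engine death")


async def main():
    config = uvicorn.Config(app, host="127.0.0.1", port=8000)
    server = uvicorn.Server(config)

    @app.exception_handler(RuntimeError)
    async def crash_handler(_, __):
        # Can't await server.shutdown() here; just ask the serve loop
        # to exit after this response has been sent.
        server.should_exit = True
        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    # serve() returns once should_exit is observed, so one request to
    # /boom gets a 500 and then brings the whole process down gracefully.
    await server.serve()


if __name__ == "__main__":
    asyncio.run(main())

That is the same "respond first, then exit" behavior the new shutdown test asserts against the real server.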
@@ -357,6 +357,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:

     shutdown_task = await serve_http(
         app,
+        engine=async_engine_client,
         host=args.host,
         port=args.port,
         log_level=args.uvicorn_log_level,
@@ -33,6 +33,7 @@ class AsyncEngineRPCClient:

         # Wait until server is ready.
         await self.wait_for_server()
+        self._errored = False

         # Get the configs.
         self.model_config = await self._get_model_config_rpc()
@@ -169,7 +170,7 @@ class AsyncEngineRPCClient:
             expected_type=SchedulerConfig,
             error_message="Could not get SchedulerConfig from RPC Server")

-    async def _get_lora_config_rpc(self):
+    async def _get_lora_config_rpc(self) -> LoRAConfig:
         """Get LoRAConfig from the RPCServer"""

         return await self._send_get_data_rpc_request(
@@ -177,7 +178,7 @@ class AsyncEngineRPCClient:
             expected_type=LoRAConfig,
             error_message="Could not get LoRAConfig from RPC Server")

-    async def _is_tracing_enabled_rpc(self) -> ParallelConfig:
+    async def _is_tracing_enabled_rpc(self) -> bool:
         """Get is_tracing_enabled flag from the RPCServer"""

         return await self._send_get_data_rpc_request(
@@ -200,6 +201,18 @@ class AsyncEngineRPCClient:
             request=RPCUtilityRequest.DO_LOG_STATS,
             error_message="RPCRequest DO_LOG_STATS failed.")

+    @property
+    def is_running(self) -> bool:
+        return not self._errored
+
+    @property
+    def is_stopped(self) -> bool:
+        return self._errored
+
+    @property
+    def errored(self) -> bool:
+        return self._errored
+
     async def generate(
         self,
         inputs: PromptInputs,
@@ -233,6 +246,15 @@ class AsyncEngineRPCClient:
                 request_output = cloudpickle.loads(message)

                 if isinstance(request_output, Exception):
+                    # On exception, check if the server is still healthy.
+                    # Use this to set the sync `is_running` and `errored`
+                    # properties.
+                    try:
+                        await self.check_health()
+                    except Exception:
+                        self._errored = True
+                    # NB: do before raising here so that the flag is set
+                    # by the time the caller receives this exception
                     raise request_output

                 finished = request_output.finished
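The ordering called out in the NB comment is deliberate: the client flips self._errored before re-raising so that any caller that catches the exception and then consults the errored / is_running properties already sees the failed state. A toy illustration of the pattern (illustrative class, not vLLM's):

class Client:

    def __init__(self):
        self.errored = False

    def recv(self):
        try:
            raise RuntimeError("engine returned an error")
        except RuntimeError:
            # Flip the flag *before* propagating so the caller's
            # except block already observes errored == True.
            self.errored = True
            raise


client = Client()
try:
    client.recv()
except RuntimeError:
    print(client.errored)  # True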
@@ -96,14 +96,17 @@ class AsyncEngineRPCServer:
     async def abort(self, identity, request: RPCAbortRequest):
         """Abort request and notify the client of success."""
-        # Abort the request in the llm engine.
-        await self.engine.abort(request.request_id)
-
-        # Send confirmation to the client.
-        await self.socket.send_multipart([
-            identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
+        try:
+            # Abort the request in the llm engine.
+            await self.engine.abort(request.request_id)
+        except Exception:
+            logger.warning("Failed to abort request %s", request.request_id)
+        finally:
+            # Send confirmation to the client.
+            await self.socket.send_multipart([
+                identity,
+                cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+            ])

     async def generate(self, identity, generate_request: RPCGenerateRequest):
         try:
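Moving the abort into try/except/finally means the confirmation is sent back over the socket even when the engine is already dead and abort() itself raises, so the client waiting on the reply does not hang. A tiny sketch of that control flow (hypothetical helper names):

def abort_and_ack(abort, send_ack):
    try:
        abort()  # may raise if the engine has already died
    except Exception:
        print("warning: failed to abort request")
    finally:
        send_ack()  # the waiting client always gets its confirmation


def dead_engine_abort():
    raise RuntimeError("engine is dead")


abort_and_ack(dead_engine_abort, lambda: print("ack sent"))
# prints the warning, then "ack sent"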
@@ -49,6 +49,7 @@ if TYPE_CHECKING:
     NVCC_THREADS: Optional[str] = None
     VLLM_USE_PRECOMPILED: bool = False
     VLLM_NO_DEPRECATION_WARNING: bool = False
+    VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False
     CMAKE_BUILD_TYPE: Optional[str] = None
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
@@ -335,6 +336,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_NO_DEPRECATION_WARNING":
     lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),

+    # If set, the OpenAI API server will stay alive even after the underlying
+    # AsyncLLMEngine errors and stops serving requests
+    "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH":
+    lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)),
+
     # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
     # the user to specify a max sequence length greater than
     # the max length derived from the model's config.json.
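One thing worth noting about the new entry: it parses the variable with bool(os.getenv(...)) rather than bool(int(os.getenv(...))) like VLLM_NO_DEPRECATION_WARNING above it, so any non-empty value, including "0", turns keep-alive on; to keep the new shutdown-on-engine-death default, simply leave the variable unset. A quick check of that parsing behavior:

import os


def keep_alive() -> bool:
    # Same expression as the new envs.py entry.
    return bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0))


os.environ["VLLM_KEEP_ALIVE_ON_ENGINE_DEATH"] = "0"
print(keep_alive())  # True: any non-empty string is truthy

del os.environ["VLLM_KEEP_ALIVE_ON_ENGINE_DEATH"]
print(keep_alive())  # False: unset falls back to the integer default 0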