diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 931063d9..add5c919 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -2,6 +2,7 @@ import asyncio
 import importlib
 import inspect
 import re
+import signal
 from contextlib import asynccontextmanager
 from http import HTTPStatus
 from typing import Optional, Set
@@ -213,12 +214,13 @@ def build_app(args):
     return app
 
 
-def run_server(args, llm_engine=None):
+async def build_server(
+    args,
+    llm_engine: Optional[AsyncLLMEngine] = None,
+    **uvicorn_kwargs,
+) -> uvicorn.Server:
     app = build_app(args)
 
-    logger.info("vLLM API server version %s", VLLM_VERSION)
-    logger.info("args: %s", args)
-
     if args.served_model_name is not None:
         served_model_names = args.served_model_name
     else:
@@ -231,19 +233,7 @@
               if llm_engine is not None else AsyncLLMEngine.from_engine_args(
                   engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
 
-    event_loop: Optional[asyncio.AbstractEventLoop]
-    try:
-        event_loop = asyncio.get_running_loop()
-    except RuntimeError:
-        event_loop = None
-
-    if event_loop is not None and event_loop.is_running():
-        # If the current is instanced by Ray Serve,
-        # there is already a running event loop
-        model_config = event_loop.run_until_complete(engine.get_model_config())
-    else:
-        # When using single vLLM without engine_use_ray
-        model_config = asyncio.run(engine.get_model_config())
+    model_config = await engine.get_model_config()
 
     if args.disable_log_requests:
         request_logger = None
@@ -296,15 +286,48 @@
         methods = ', '.join(route.methods)
         logger.info("Route: %s, Methods: %s", route.path, methods)
 
-    uvicorn.run(app,
-                host=args.host,
-                port=args.port,
-                log_level=args.uvicorn_log_level,
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
-                ssl_keyfile=args.ssl_keyfile,
-                ssl_certfile=args.ssl_certfile,
-                ssl_ca_certs=args.ssl_ca_certs,
-                ssl_cert_reqs=args.ssl_cert_reqs)
+    config = uvicorn.Config(
+        app,
+        host=args.host,
+        port=args.port,
+        log_level=args.uvicorn_log_level,
+        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+        ssl_keyfile=args.ssl_keyfile,
+        ssl_certfile=args.ssl_certfile,
+        ssl_ca_certs=args.ssl_ca_certs,
+        ssl_cert_reqs=args.ssl_cert_reqs,
+        **uvicorn_kwargs,
+    )
+
+    return uvicorn.Server(config)
+
+
+async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
+    logger.info("vLLM API server version %s", VLLM_VERSION)
+    logger.info("args: %s", args)
+
+    server = await build_server(
+        args,
+        llm_engine,
+        **uvicorn_kwargs,
+    )
+
+    loop = asyncio.get_running_loop()
+
+    server_task = loop.create_task(server.serve())
+
+    def signal_handler() -> None:
+        # prevent the uvicorn signal handler from exiting early
+        server_task.cancel()
+
+    loop.add_signal_handler(signal.SIGINT, signal_handler)
+    loop.add_signal_handler(signal.SIGTERM, signal_handler)
+
+    try:
+        await server_task
+    except asyncio.CancelledError:
+        print("Gracefully stopping http server")
+        await server.shutdown()
 
 
 if __name__ == "__main__":
@@ -314,4 +337,4 @@ if __name__ == "__main__":
         description="vLLM OpenAI-Compatible RESTful API server.")
     parser = make_arg_parser(parser)
     args = parser.parse_args()
-    run_server(args)
+    asyncio.run(run_server(args))
diff --git a/vllm/scripts.py b/vllm/scripts.py
index 3f334be9..aefa5cec 100644
--- a/vllm/scripts.py
+++ b/vllm/scripts.py
@@ -1,5 +1,6 @@
 # The CLI entrypoint to vLLM.
 import argparse
+import asyncio
 import os
 import signal
 import sys
@@ -25,7 +26,7 @@ def serve(args: argparse.Namespace) -> None:
     # EngineArgs expects the model name to be passed as --model.
     args.model = args.model_tag
 
-    run_server(args)
+    asyncio.run(run_server(args))
 
 
 def interactive_cli(args: argparse.Namespace) -> None:
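
With this change `run_server` is a coroutine, so a caller that already owns an event loop (the Ray Serve scenario the deleted `run_until_complete` branch special-cased) can simply await it, or await `build_server` and drive the returned `uvicorn.Server` itself; the custom SIGINT/SIGTERM handlers keep uvicorn's defaults from tearing the process down before `server.shutdown()` runs. Below is a minimal sketch of both entry points against the new API; the `cli_args` import path and the model name are illustrative assumptions, not part of this diff:

```python
import argparse
import asyncio

from vllm.entrypoints.openai.api_server import build_server, run_server
# Assumed import path for make_arg_parser, matching its use in __main__ above.
from vllm.entrypoints.openai.cli_args import make_arg_parser


async def main() -> None:
    parser = make_arg_parser(
        argparse.ArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    # Placeholder model; substitute whatever you actually serve.
    args = parser.parse_args(["--model", "facebook/opt-125m"])

    # Simple case: run_server() drives uvicorn and installs its own
    # signal handlers for a graceful shutdown.
    await run_server(args)


async def embedded(args: argparse.Namespace) -> None:
    # Embedded case (e.g. under Ray Serve): the caller already has a running
    # event loop, so it awaits build_server() directly and schedules serve()
    # as one task among many -- no run_until_complete() workaround needed.
    server = await build_server(args)
    serve_task = asyncio.get_running_loop().create_task(server.serve())
    await serve_task


if __name__ == "__main__":
    asyncio.run(main())
```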