From 981b0d567355063d5453e382a85970cae083c615 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Sat, 27 Jul 2024 09:58:25 +0800
Subject: [PATCH] [Frontend] Factor out code for running uvicorn (#6828)

---
 vllm/entrypoints/api_server.py        | 74 ++++++++++++++++++---------
 vllm/entrypoints/openai/api_server.py | 72 ++++++++------------------
 vllm/server/__init__.py               |  3 ++
 vllm/server/launch.py                 | 42 +++++++++++++++
 4 files changed, 116 insertions(+), 75 deletions(-)
 create mode 100644 vllm/server/__init__.py
 create mode 100644 vllm/server/launch.py

diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 66941442..34763576 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -5,12 +5,12 @@ For production use, we recommend using our OpenAI compatible server.
 We are also not going to accept PRs modifying this file, please change
 `vllm/entrypoints/openai/api_server.py` instead.
 """
-
+import asyncio
 import json
 import ssl
-from typing import AsyncGenerator
+from argparse import Namespace
+from typing import Any, AsyncGenerator, Optional
 
-import uvicorn
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
@@ -18,8 +18,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
+from vllm.server import serve_http
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, random_uuid
+from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger("vllm.entrypoints.api_server")
 
@@ -81,6 +83,50 @@ async def generate(request: Request) -> Response:
     return JSONResponse(ret)
 
 
+def build_app(args: Namespace) -> FastAPI:
+    global app
+
+    app.root_path = args.root_path
+    return app
+
+
+async def init_app(
+    args: Namespace,
+    llm_engine: Optional[AsyncLLMEngine] = None,
+) -> FastAPI:
+    app = build_app(args)
+
+    global engine
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = (llm_engine
+              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
+                  engine_args, usage_context=UsageContext.API_SERVER))
+
+    return app
+
+
+async def run_server(args: Namespace,
+                     llm_engine: Optional[AsyncLLMEngine] = None,
+                     **uvicorn_kwargs: Any) -> None:
+    logger.info("vLLM API server version %s", VLLM_VERSION)
+    logger.info("args: %s", args)
+
+    app = await init_app(args, llm_engine)
+    await serve_http(
+        app,
+        host=args.host,
+        port=args.port,
+        log_level=args.log_level,
+        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+        ssl_keyfile=args.ssl_keyfile,
+        ssl_certfile=args.ssl_certfile,
+        ssl_ca_certs=args.ssl_ca_certs,
+        ssl_cert_reqs=args.ssl_cert_reqs,
+        **uvicorn_kwargs,
+    )
+
+
 if __name__ == "__main__":
     parser = FlexibleArgumentParser()
     parser.add_argument("--host", type=str, default=None)
@@ -105,25 +151,5 @@ if __name__ == "__main__":
     parser.add_argument("--log-level", type=str, default="debug")
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(
-        engine_args, usage_context=UsageContext.API_SERVER)
-
-    app.root_path = args.root_path
-
-    logger.info("Available routes are:")
-    for route in app.routes:
-        if not hasattr(route, 'methods'):
-            continue
-        methods = ', '.join(route.methods)
-        logger.info("Route: %s, Methods: %s", route.path, methods)
-
-    uvicorn.run(app,
-                host=args.host,
-                port=args.port,
-                log_level=args.log_level,
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
-                ssl_keyfile=args.ssl_keyfile,
-                ssl_certfile=args.ssl_certfile,
-                ssl_ca_certs=args.ssl_ca_certs,
-                ssl_cert_reqs=args.ssl_cert_reqs)
+
+    asyncio.run(run_server(args))
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 0fe4dd24..c1640a10 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -2,14 +2,12 @@ import asyncio
 import importlib
 import inspect
 import re
-import signal
+from argparse import Namespace
 from contextlib import asynccontextmanager
 from http import HTTPStatus
-from typing import Optional, Set
+from typing import Any, Optional, Set
 
-import fastapi
-import uvicorn
-from fastapi import APIRouter, Request
+from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -38,6 +36,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.logger import init_logger
+from vllm.server import serve_http
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser
 from vllm.version import __version__ as VLLM_VERSION
@@ -57,7 +56,7 @@ _running_tasks: Set[asyncio.Task] = set()
 
 
 @asynccontextmanager
-async def lifespan(app: fastapi.FastAPI):
+async def lifespan(app: FastAPI):
 
     async def _force_log():
         while True:
@@ -75,7 +74,7 @@ async def lifespan(app: fastapi.FastAPI):
 router = APIRouter()
 
 
-def mount_metrics(app: fastapi.FastAPI):
+def mount_metrics(app: FastAPI):
     # Add prometheus asgi middleware to route /metrics requests
     metrics_route = Mount("/metrics", make_asgi_app())
     # Workaround for 307 Redirect for /metrics
@@ -165,8 +164,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
         return JSONResponse(content=generator.model_dump())
 
 
-def build_app(args):
-    app = fastapi.FastAPI(lifespan=lifespan)
+def build_app(args: Namespace) -> FastAPI:
+    app = FastAPI(lifespan=lifespan)
     app.include_router(router)
     app.root_path = args.root_path
 
@@ -214,11 +213,8 @@ def build_app(args):
     return app
 
 
-async def build_server(
-    args,
-    llm_engine: Optional[AsyncLLMEngine] = None,
-    **uvicorn_kwargs,
-) -> uvicorn.Server:
+async def init_app(args: Namespace,
+                   llm_engine: Optional[AsyncLLMEngine] = None) -> FastAPI:
     app = build_app(args)
 
     if args.served_model_name is not None:
@@ -281,14 +277,17 @@ async def build_server(
     )
     app.root_path = args.root_path
 
-    logger.info("Available routes are:")
-    for route in app.routes:
-        if not hasattr(route, 'methods'):
-            continue
-        methods = ', '.join(route.methods)
-        logger.info("Route: %s, Methods: %s", route.path, methods)
+    return app
 
-    config = uvicorn.Config(
+
+async def run_server(args: Namespace,
+                     llm_engine: Optional[AsyncLLMEngine] = None,
+                     **uvicorn_kwargs: Any) -> None:
+    logger.info("vLLM API server version %s", VLLM_VERSION)
+    logger.info("args: %s", args)
+
+    app = await init_app(args, llm_engine)
+    await serve_http(
         app,
         host=args.host,
         port=args.port,
@@ -301,36 +300,6 @@ async def build_server(
         **uvicorn_kwargs,
     )
 
-    return uvicorn.Server(config)
-
-
-async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
-    logger.info("vLLM API server version %s", VLLM_VERSION)
-    logger.info("args: %s", args)
-
-    server = await build_server(
-        args,
-        llm_engine,
-        **uvicorn_kwargs,
-    )
-
-    loop = asyncio.get_running_loop()
-
-    server_task = loop.create_task(server.serve())
-
-    def signal_handler() -> None:
-        # prevents the uvicorn signal handler to exit early
-        server_task.cancel()
-
-    loop.add_signal_handler(signal.SIGINT, signal_handler)
-    loop.add_signal_handler(signal.SIGTERM, signal_handler)
-
-    try:
-        await server_task
-    except asyncio.CancelledError:
-        print("Gracefully stopping http server")
-        await server.shutdown()
-
 
 if __name__ == "__main__":
     # NOTE(simon):
@@ -339,4 +308,5 @@ if __name__ == "__main__":
         description="vLLM OpenAI-Compatible RESTful API server.")
     parser = make_arg_parser(parser)
     args = parser.parse_args()
+
     asyncio.run(run_server(args))
diff --git a/vllm/server/__init__.py b/vllm/server/__init__.py
new file mode 100644
index 00000000..17c98b4d
--- /dev/null
+++ b/vllm/server/__init__.py
@@ -0,0 +1,3 @@
+from .launch import serve_http
+
+__all__ = ["serve_http"]
diff --git a/vllm/server/launch.py b/vllm/server/launch.py
new file mode 100644
index 00000000..1a8aeb7f
--- /dev/null
+++ b/vllm/server/launch.py
@@ -0,0 +1,42 @@
+import asyncio
+import signal
+from typing import Any
+
+import uvicorn
+from fastapi import FastAPI
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+async def serve_http(app: FastAPI, **uvicorn_kwargs: Any) -> None:
+    logger.info("Available routes are:")
+    for route in app.routes:
+        methods = getattr(route, "methods", None)
+        path = getattr(route, "path", None)
+
+        if methods is None or path is None:
+            continue
+
+        logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
+
+    config = uvicorn.Config(app, **uvicorn_kwargs)
+    server = uvicorn.Server(config)
+
+    loop = asyncio.get_running_loop()
+
+    server_task = loop.create_task(server.serve())
+
+    def signal_handler() -> None:
+        # prevents the uvicorn signal handler to exit early
+        server_task.cancel()
+
+    loop.add_signal_handler(signal.SIGINT, signal_handler)
+    loop.add_signal_handler(signal.SIGTERM, signal_handler)
+
+    try:
+        await server_task
+    except asyncio.CancelledError:
+        logger.info("Gracefully stopping http server")
+        await server.shutdown()
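
Usage note (not part of the patch): the new `vllm.server.serve_http` helper can be driven from any asyncio program. Every keyword argument is forwarded to `uvicorn.Config`, and SIGINT/SIGTERM cancel the serve task so the graceful-shutdown path runs. A minimal sketch, assuming this patch is applied; the demo app, route, host and port below are placeholders:

import asyncio

from fastapi import FastAPI

from vllm.server import serve_http

# Hypothetical demo app; any FastAPI/ASGI app can be served this way.
app = FastAPI()


@app.get("/health")
async def health() -> dict:
    return {"status": "ok"}


async def main() -> None:
    # Keyword arguments are passed straight through to uvicorn.Config,
    # so anything uvicorn accepts (log_level, ssl_keyfile, ...) works here.
    # Ctrl-C (SIGINT) or SIGTERM triggers the graceful shutdown inside serve_http.
    await serve_http(app, host="127.0.0.1", port=8000, log_level="info")


if __name__ == "__main__":
    asyncio.run(main())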
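Similarly, the reworked `run_server` entry points accept a pre-built `AsyncLLMEngine`, which is what makes the servers embeddable after this refactor. A rough sketch against the OpenAI-compatible server, again assuming this patch; the model name, the `access_log` override and the `cli_args` import path are illustrative assumptions, not prescribed by the patch:

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser


async def main() -> None:
    parser = make_arg_parser(FlexibleArgumentParser())
    # Placeholder model; any arguments understood by the CLI are accepted.
    args = parser.parse_args(["--model", "facebook/opt-125m"])

    # Build the engine ourselves instead of letting init_app create it,
    # then hand it to run_server via the llm_engine parameter.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs.from_cli_args(args),
        usage_context=UsageContext.OPENAI_API_SERVER)

    # Extra keyword arguments reach uvicorn.Config through serve_http.
    await run_server(args, llm_engine=engine, access_log=False)


if __name__ == "__main__":
    asyncio.run(main())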