[Frontend] split run_server into build_server and run_server (#6740)

This commit is contained in:
Daniele 2024-07-24 19:36:04 +02:00 committed by GitHub
parent 40468b13fa
commit ee812580f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 28 deletions

View File

@ -2,6 +2,7 @@ import asyncio
import importlib
import inspect
import re
import signal
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Optional, Set
@ -213,12 +214,13 @@ def build_app(args):
return app
def run_server(args, llm_engine=None):
async def build_server(
args,
llm_engine: Optional[AsyncLLMEngine] = None,
**uvicorn_kwargs,
) -> uvicorn.Server:
app = build_app(args)
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
@ -231,19 +233,7 @@ def run_server(args, llm_engine=None):
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
event_loop: Optional[asyncio.AbstractEventLoop]
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
event_loop = None
if event_loop is not None and event_loop.is_running():
# If the current is instanced by Ray Serve,
# there is already a running event loop
model_config = event_loop.run_until_complete(engine.get_model_config())
else:
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())
model_config = await engine.get_model_config()
if args.disable_log_requests:
request_logger = None
@ -296,7 +286,8 @@ def run_server(args, llm_engine=None):
methods = ', '.join(route.methods)
logger.info("Route: %s, Methods: %s", route.path, methods)
uvicorn.run(app,
config = uvicorn.Config(
app,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
@ -304,7 +295,39 @@ def run_server(args, llm_engine=None):
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)
return uvicorn.Server(config)
async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
server = await build_server(
args,
llm_engine,
**uvicorn_kwargs,
)
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.serve())
def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
try:
await server_task
except asyncio.CancelledError:
print("Gracefully stopping http server")
await server.shutdown()
if __name__ == "__main__":
@ -314,4 +337,4 @@ if __name__ == "__main__":
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()
run_server(args)
asyncio.run(run_server(args))

View File

@ -1,5 +1,6 @@
# The CLI entrypoint to vLLM.
import argparse
import asyncio
import os
import signal
import sys
@ -25,7 +26,7 @@ def serve(args: argparse.Namespace) -> None:
# EngineArgs expects the model name to be passed as --model.
args.model = args.model_tag
run_server(args)
asyncio.run(run_server(args))
def interactive_cli(args: argparse.Namespace) -> None: