From 98fe8cb5420c28fa8dcc3110b6c898848dd57e45 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 3 Jul 2023 23:01:56 -0700 Subject: [PATCH] [Server] Add option to specify chat template for chat endpoint (#345) --- requirements.txt | 1 + vllm/entrypoints/openai/api_server.py | 28 +++++++++++++++++++++------ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/requirements.txt b/requirements.txt index e84873ed..4acd4092 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ xformers >= 0.0.19 fastapi uvicorn pydantic # Required for OpenAI server. +fschat # Required for OpenAI ChatCompletion Endpoint. diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 75bf07e9..2d1dcab3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -36,6 +36,7 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds logger = init_logger(__name__) served_model = None +chat_template = None app = fastapi.FastAPI() @@ -62,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]: async def get_gen_prompt(request) -> str: - conv = get_conv_template(request.model) + conv = get_conv_template(chat_template) conv = Conversation( name=conv.name, system=conv.system, @@ -553,13 +554,20 @@ if __name__ == "__main__": type=json.loads, default=["*"], help="allowed headers") + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. If not " + "specified, the model name will be the same as " + "the huggingface name.") parser.add_argument( - "--served-model-name", + "--chat-template", type=str, default=None, - help="The model name used in the API. If not specified, " - "the model name will be the same as the " - "huggingface name.") + help="The chat template name used in the ChatCompletion endpoint. If " + "not specified, we use the API model name as the template name. See " + "https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py " + "for the list of available templates.") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() @@ -573,7 +581,15 @@ if __name__ == "__main__": logger.info(f"args: {args}") - served_model = args.served_model_name or args.model + if args.served_model_name is not None: + served_model = args.served_model_name + else: + served_model = args.model + + if args.chat_template is not None: + chat_template = args.chat_template + else: + chat_template = served_model engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args)