From 66ded030677c7a0ca696f8d64e41637f4a358c00 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Thu, 18 Apr 2024 08:16:26 +0100
Subject: [PATCH] Allow model to be served under multiple names (#2894)

Co-authored-by: Alexandre Payot
---
 vllm/entrypoints/openai/api_server.py         |  8 ++++----
 vllm/entrypoints/openai/cli_args.py           | 10 +++++++---
 vllm/entrypoints/openai/serving_chat.py       |  8 ++++----
 vllm/entrypoints/openai/serving_completion.py |  6 +++---
 vllm/entrypoints/openai/serving_engine.py     | 15 ++++++++-------
 5 files changed, 26 insertions(+), 21 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 32282bfd..d6673976 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -150,18 +150,18 @@ if __name__ == "__main__":
     logger.info(f"args: {args}")
 
     if args.served_model_name is not None:
-        served_model = args.served_model_name
+        served_model_names = args.served_model_name
     else:
-        served_model = args.model
+        served_model_names = [args.model]
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(
         engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-    openai_serving_chat = OpenAIServingChat(engine, served_model,
+    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                             args.response_role,
                                             args.lora_modules,
                                             args.chat_template)
     openai_serving_completion = OpenAIServingCompletion(
-        engine, served_model, args.lora_modules)
+        engine, served_model_names, args.lora_modules)
 
     app.root_path = args.root_path
     uvicorn.run(app,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index cc71931b..5c361b4d 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -54,11 +54,15 @@ def make_arg_parser():
                         help="If provided, the server will require this key "
                         "to be presented in the header.")
     parser.add_argument("--served-model-name",
+                        nargs="+",
                         type=str,
                         default=None,
-                        help="The model name used in the API. If not "
-                        "specified, the model name will be the same as "
-                        "the huggingface name.")
+                        help="The model name(s) used in the API. If multiple "
+                        "names are provided, the server will respond to any "
+                        "of the provided names. The model name in the model "
+                        "field of a response will be the first name in this "
+                        "list. If not specified, the model name will be the "
+                        "same as the `--model` argument.")
     parser.add_argument(
         "--lora-modules",
         type=str,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index c9ed4a9d..f35eab15 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -24,12 +24,12 @@ class OpenAIServingChat(OpenAIServing):
 
     def __init__(self,
                  engine: AsyncLLMEngine,
-                 served_model: str,
+                 served_model_names: List[str],
                  response_role: str,
                  lora_modules: Optional[List[LoRA]] = None,
                  chat_template=None):
         super().__init__(engine=engine,
-                         served_model=served_model,
+                         served_model_names=served_model_names,
                          lora_modules=lora_modules)
         self.response_role = response_role
         self._load_chat_template(chat_template)
@@ -109,7 +109,7 @@ class OpenAIServingChat(OpenAIServing):
             result_generator: AsyncIterator[RequestOutput], request_id: str
     ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
 
-        model_name = request.model
+        model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
         first_iteration = True
@@ -251,7 +251,7 @@ class OpenAIServingChat(OpenAIServing):
             result_generator: AsyncIterator[RequestOutput],
             request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
 
-        model_name = request.model
+        model_name = self.served_model_names[0]
         created_time = int(time.time())
         final_res: RequestOutput = None
 
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index a71f2d6a..b7e2530a 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -53,10 +53,10 @@ class OpenAIServingCompletion(OpenAIServing):
 
     def __init__(self,
                  engine: AsyncLLMEngine,
-                 served_model: str,
+                 served_model_names: List[str],
                  lora_modules: Optional[List[LoRA]] = None):
         super().__init__(engine=engine,
-                         served_model=served_model,
+                         served_model_names=served_model_names,
                          lora_modules=lora_modules)
 
     async def create_completion(self, request: CompletionRequest,
@@ -79,7 +79,7 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(
                 "suffix is not currently supported")
 
-        model_name = request.model
+        model_name = self.served_model_names[0]
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
 
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 77a568b5..b5a7a977 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -29,10 +29,10 @@ class OpenAIServing:
 
     def __init__(self,
                  engine: AsyncLLMEngine,
-                 served_model: str,
+                 served_model_names: List[str],
                  lora_modules=Optional[List[LoRA]]):
         self.engine = engine
-        self.served_model = served_model
+        self.served_model_names = served_model_names
         if lora_modules is None:
             self.lora_requests = []
         else:
@@ -74,13 +74,14 @@ class OpenAIServing:
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
-            ModelCard(id=self.served_model,
-                      root=self.served_model,
+            ModelCard(id=served_model_name,
+                      root=self.served_model_names[0],
                       permission=[ModelPermission()])
+            for served_model_name in self.served_model_names
         ]
         lora_cards = [
             ModelCard(id=lora.lora_name,
-                      root=self.served_model,
+                      root=self.served_model_names[0],
                       permission=[ModelPermission()])
             for lora in self.lora_requests
         ]
@@ -150,7 +151,7 @@ class OpenAIServing:
         return json_str
 
     async def _check_model(self, request) -> Optional[ErrorResponse]:
-        if request.model == self.served_model:
+        if request.model in self.served_model_names:
             return
         if request.model in [lora.lora_name for lora in self.lora_requests]:
             return
@@ -160,7 +161,7 @@ class OpenAIServing:
             status_code=HTTPStatus.NOT_FOUND)
 
     def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
-        if request.model == self.served_model:
+        if request.model in self.served_model_names:
             return
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
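Example usage (illustrative only, not part of the patch): a minimal sketch of how the new multi-name behaviour might be exercised, assuming the OpenAI Python client (>= 1.0) and a locally running server. The model path, the aliases "opt" and "small-model", the port, and the API key below are assumptions for illustration, not taken from this change.

    # Serve one model under several names (launch command shown for context;
    # the model path and the aliases are hypothetical):
    #
    #   python -m vllm.entrypoints.openai.api_server \
    #       --model facebook/opt-125m \
    #       --served-model-name facebook/opt-125m opt small-model
    from openai import OpenAI

    # Point the standard OpenAI client at the local vLLM server (assumed URL/key).
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # /v1/models now returns one ModelCard per served name.
    print([model.id for model in client.models.list().data])

    # Any configured alias is accepted by _check_model; the response's `model`
    # field reports the first name given to --served-model-name.
    completion = client.completions.create(model="small-model",
                                           prompt="Hello, my name is",
                                           max_tokens=8)
    print(completion.model, completion.choices[0].text)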