From 2c37540aa6af89f0ece874d831dff3bf62edf486 Mon Sep 17 00:00:00 2001
From: danieljannai21 <100521221+danieljannai21@users.noreply.github.com>
Date: Tue, 2 Jul 2024 09:01:57 +0300
Subject: [PATCH] [Frontend] Add template related params to request (#5709)

---
 requirements-common.txt                 |  2 +-
 vllm/entrypoints/openai/protocol.py     | 21 +++++++++++++++++++++
 vllm/entrypoints/openai/serving_chat.py |  8 ++++++++
 3 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/requirements-common.txt b/requirements-common.txt
index 636f8534..765568b0 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -6,7 +6,7 @@ numpy < 2.0.0
 requests
 tqdm
 py-cpuinfo
-transformers >= 4.42.0  # Required for Gemma 2.
+transformers >= 4.42.0  # Required for Gemma 2 and for additional chat template parameters.
 tokenizers >= 0.19.1  # Required for Llama 3.
 fastapi
 aiohttp
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index d1568cb3..7f97e534 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -190,6 +190,27 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "special tokens so this should be set to False (as is the "
             "default)."),
     )
+    documents: Optional[List[Dict[str, str]]] = Field(
+        default=None,
+        description=
+        ("A list of dicts representing documents that will be accessible to "
+         "the model if it is performing RAG (retrieval-augmented generation)."
+         " If the template does not support RAG, this argument will have no "
+         "effect. We recommend that each document should be a dict containing "
+         "\"title\" and \"text\" keys."),
+    )
+    chat_template: Optional[str] = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "If this is not passed, the model's default chat template will be "
+            "used instead."),
+    )
+    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description=("Additional kwargs to pass to the template renderer. "
+                     "Will be accessible by the chat template."),
+    )
     include_stop_str_in_output: Optional[bool] = Field(
         default=False,
         description=(
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 744e1d94..4a960fd7 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -218,10 +218,18 @@ class OpenAIServingChat(OpenAIServing):
                 conversation.extend(chat_parsed_result.messages)
                 image_futures.extend(chat_parsed_result.image_futures)
 
+            tool_dicts = None if request.tools is None else [
+                tool.model_dump() for tool in request.tools
+            ]
+
             prompt = self.tokenizer.apply_chat_template(
                 conversation=conversation,
                 tokenize=False,
                 add_generation_prompt=request.add_generation_prompt,
+                tools=tool_dicts,
+                documents=request.documents,
+                chat_template=request.chat_template,
+                **(request.chat_template_kwargs or {}),
            )
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
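
Usage sketch (not part of the patch): one way a client could exercise the new
request fields against a vLLM OpenAI-compatible server. The server URL, model
name, document contents, and the "enable_citations" template variable below
are illustrative assumptions, not values defined by this change; a custom
Jinja string could likewise be sent in the "chat_template" field.

    # Minimal client-side sketch of the new fields (all values assumed).
    import requests

    payload = {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",  # assumed model
        "messages": [
            {"role": "user", "content": "Where were the 2024 Olympics held?"},
        ],
        # Forwarded to tokenizer.apply_chat_template(); templates without
        # RAG support simply ignore these documents.
        "documents": [
            {"title": "Olympics",
             "text": "The 2024 Summer Olympics were held in Paris."},
        ],
        # Extra kwargs exposed to the Jinja chat template. "enable_citations"
        # is a hypothetical variable that a custom template might read.
        "chat_template_kwargs": {"enable_citations": True},
    }

    resp = requests.post("http://localhost:8000/v1/chat/completions",
                         json=payload)
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])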