From 7c3604fb68031da36567151a9bdfe69e04de44b8 Mon Sep 17 00:00:00 2001
From: Itay Etelis <92247226+Etelis@users.noreply.github.com>
Date: Thu, 30 May 2024 02:13:22 +0300
Subject: [PATCH] [Bugfix] logprobs is not compatible with the OpenAI spec
 #4795 (#5031)

---
 vllm/entrypoints/openai/protocol.py     | 5 ++---
 vllm/entrypoints/openai/serving_chat.py | 4 ++--
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 41e2f77f..e6eae689 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -109,7 +109,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False
-    top_logprobs: Optional[int] = None
+    top_logprobs: Optional[int] = 0
     max_tokens: Optional[int] = None
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
@@ -192,8 +192,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # doc: end-chat-completion-extra-params

     def to_sampling_params(self) -> SamplingParams:
-        if self.logprobs and not self.top_logprobs:
-            raise ValueError("Top logprobs must be set when logprobs is.")
+        # We now allow logprobs to be true without top_logprobs.

         logits_processors = None
         if self.logit_bias:
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 33daabd8..8cb50e33 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -286,7 +286,7 @@ class OpenAIServingChat(OpenAIServing):
                 logprobs = self._create_logprobs(
                     token_ids=delta_token_ids,
                     top_logprobs=top_logprobs,
-                    num_output_top_logprobs=request.logprobs,
+                    num_output_top_logprobs=request.top_logprobs,
                     initial_text_offset=len(previous_texts[i]),
                 )
             else:
@@ -373,7 +373,7 @@ class OpenAIServingChat(OpenAIServing):
             logprobs = self._create_logprobs(
                 token_ids=token_ids,
                 top_logprobs=top_logprobs,
-                num_output_top_logprobs=request.logprobs,
+                num_output_top_logprobs=request.top_logprobs,
             )
         else:
             logprobs = None
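
Below is a minimal sketch of the behavior this patch enables, assuming a vLLM
OpenAI-compatible server at http://localhost:8000 and the openai Python client
(v1.x); the model name is a placeholder and not part of the patch. Per the
OpenAI spec, logprobs=true is valid on its own and top_logprobs is optional;
before this change vLLM rejected such requests with "Top logprobs must be set
when logprobs is."

    # Sketch only; assumes a vLLM OpenAI-compatible server on localhost:8000
    # and the openai>=1.0 Python client. Model name is a placeholder.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # logprobs=True with no top_logprobs: now accepted, matching the OpenAI
    # spec. With the patch, top_logprobs defaults to 0, so only the sampled
    # token's logprob is returned and the alternatives list stays empty.
    completion = client.chat.completions.create(
        model="meta-llama/Llama-2-7b-chat-hf",  # placeholder model name
        messages=[{"role": "user", "content": "Say hello."}],
        logprobs=True,
    )

    # Each returned token carries its own log probability.
    for tok in completion.choices[0].logprobs.content:
        print(tok.token, tok.logprob)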