From 95e7d4a97cd64f8c6dc226ec0bbceebef6458701 Mon Sep 17 00:00:00 2001 From: Dylan Hawk <51147702+dylanwhawk@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:15:50 -0700 Subject: [PATCH] Fix echo/logprob OpenAI completion bug (#3441) Co-authored-by: Dylan Hawk --- tests/entrypoints/test_openai_server.py | 31 ++++++++++++ vllm/entrypoints/openai/serving_chat.py | 9 ++-- vllm/entrypoints/openai/serving_completion.py | 15 ++++-- vllm/entrypoints/openai/serving_engine.py | 47 +++++++++++-------- 4 files changed, 73 insertions(+), 29 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 6f2086c4..7940430b 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -742,5 +742,36 @@ number: "1" | "2" assert content.strip() == ground_truth +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, + model_name: str): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=1) + + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert (completion.choices[0].text is not None + and re.search(r"^" + prompt_text, completion.choices[0].text)) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + assert len(logprobs.tokens) > 5 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0980c3d3..a03c5dc8 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -63,8 +63,9 @@ class OpenAIServingChat(OpenAIServing): request_id = f"cmpl-{random_uuid()}" try: - token_ids = self._validate_prompt_and_tokenize(request, - prompt=prompt) + # Tokenize/detokenize depending on prompt format (string/token list) + prompt_ids, prompt_text = self._validate_prompt_and_tokenize( + request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) guided_decode_logits_processor = ( @@ -78,8 +79,8 @@ class OpenAIServingChat(OpenAIServing): except ValueError as e: return self.create_error_response(str(e)) - result_generator = self.engine.generate(prompt, sampling_params, - request_id, token_ids, + result_generator = self.engine.generate(prompt_text, sampling_params, + request_id, prompt_ids, lora_request) # Streaming response if request.stream: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 06e7a922..c1f1744a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -136,23 +136,24 @@ class OpenAIServingCompletion(OpenAIServing): for i, prompt in enumerate(prompts): if prompt_is_tokens: - input_ids = self._validate_prompt_and_tokenize( + prompt_formats = self._validate_prompt_and_tokenize( request, prompt_ids=prompt, truncate_prompt_tokens=sampling_params. 
truncate_prompt_tokens) else: - input_ids = self._validate_prompt_and_tokenize( + prompt_formats = self._validate_prompt_and_tokenize( request, prompt=prompt, truncate_prompt_tokens=sampling_params. truncate_prompt_tokens) + prompt_ids, prompt_text = prompt_formats generators.append( - self.engine.generate(prompt, + self.engine.generate(prompt_text, sampling_params, f"{request_id}-{i}", - prompt_token_ids=input_ids, + prompt_token_ids=prompt_ids, lora_request=lora_request)) except ValueError as e: # TODO: Use a vllm-specific Validation Error @@ -326,7 +327,8 @@ class OpenAIServingCompletion(OpenAIServing): output_text = prompt_text elif request.echo and request.max_tokens > 0: token_ids = prompt_token_ids + output.token_ids - top_logprobs = prompt_logprobs + output.logprobs + top_logprobs = (prompt_logprobs + output.logprobs + if request.logprobs else None) output_text = prompt_text + output.text else: token_ids = output.token_ids @@ -334,6 +336,9 @@ class OpenAIServingCompletion(OpenAIServing): output_text = output.text if request.logprobs is not None: + assert top_logprobs is not None, ( + "top_logprobs must be provided when logprobs " + "is requested") logprobs = self._create_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8f69388c..77a568b5 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,7 +2,7 @@ import asyncio import json from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union from pydantic import conint @@ -99,27 +99,32 @@ class OpenAIServing: last_token_len = 0 if num_output_top_logprobs: logprobs.top_logprobs = [] + for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] - if step_top_logprobs is not None: - token_logprob = step_top_logprobs[token_id].logprob + if step_top_logprobs is None: + token = self.tokenizer.decode(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(None) + logprobs.top_logprobs.append(None) else: - token_logprob = None - token = step_top_logprobs[token_id].decoded_token - logprobs.tokens.append(token) - logprobs.token_logprobs.append(token_logprob) + token_logprob = step_top_logprobs[token_id].logprob + token = step_top_logprobs[token_id].decoded_token + logprobs.tokens.append(token) + logprobs.token_logprobs.append(token_logprob) + + if num_output_top_logprobs: + logprobs.top_logprobs.append({ + p.decoded_token: p.logprob + for i, p in step_top_logprobs.items() + } if step_top_logprobs else None) + if len(logprobs.text_offset) == 0: logprobs.text_offset.append(initial_text_offset) else: logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len) last_token_len = len(token) - - if num_output_top_logprobs: - logprobs.top_logprobs.append({ - p.decoded_token: p.logprob - for i, p in step_top_logprobs.items() - } if step_top_logprobs else None) return logprobs def create_error_response( @@ -164,12 +169,12 @@ class OpenAIServing: raise ValueError("The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( - self, - request: Union[ChatCompletionRequest, CompletionRequest], - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None, - truncate_prompt_tokens: Optional[conint(ge=1)] = None - ) -> List[int]: + self, + request: Union[ChatCompletionRequest, CompletionRequest], + prompt: 
Optional[str] = None,
+            prompt_ids: Optional[List[int]] = None,
+            truncate_prompt_tokens: Optional[conint(ge=1)] = None
+    ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
@@ -187,6 +192,8 @@ class OpenAIServing:
         else:
             input_ids = prompt_ids
 
+        input_text = prompt if prompt is not None else self.tokenizer.decode(
+            prompt_ids)
         token_num = len(input_ids)
 
         if request.max_tokens is None:
@@ -201,4 +208,4 @@ class OpenAIServing:
             f"{request.max_tokens} in the completion). "
             f"Please reduce the length of the messages or completion.", )
         else:
-            return input_ids
+            return input_ids, input_text
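
Usage sketch (illustrative, not part of the commit): the snippet below exercises the echo + logprobs path that this patch fixes, mirroring the new test_echo_logprob_completion test above. It assumes a vLLM OpenAI-compatible server is already running; the base URL, API key, and model name are placeholders to adjust for your deployment.

# Client-side check of the fixed behavior: echo=True together with logprobs
# on /v1/completions. base_url, api_key, and model are placeholders for a
# locally running vLLM OpenAI-compatible server.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="your-served-model",   # placeholder: name of the served model
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    echo=True,                   # echo the prompt back in choices[0].text
    logprobs=1,                  # request per-token logprobs
)

choice = completion.choices[0]
# With echo=True, the returned text starts with the prompt itself.
assert choice.text.startswith("Hello, my name is")

# Prompt tokens are now included in the logprobs; the first prompt token has
# no logprob, so the leading entries are None (previously this combination
# raised an error instead of returning a completion).
logprobs = choice.logprobs
assert logprobs.token_logprobs[0] is None
assert logprobs.top_logprobs[0] is None
print(logprobs.tokens)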