diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py
index 05f66723..4d0c6d73 100644
--- a/tests/entrypoints/openai/test_completion.py
+++ b/tests/entrypoints/openai/test_completion.py
@@ -3,7 +3,7 @@ import json
 import re
 import shutil
 from tempfile import TemporaryDirectory
-from typing import List
+from typing import Dict, List
 
 import jsonschema
 import openai  # use the official client for correctness check
@@ -130,6 +130,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
         temperature=0.0,
     )
     assert len(completion.choices[0].text) >= 1
+    assert completion.choices[0].prompt_logprobs is None
 
 
 @pytest.mark.asyncio
@@ -267,6 +268,128 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
     assert len(completion.choices[0].text) >= 0
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name, prompt_logprobs",
+    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
+)
+async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                    model_name: str, prompt_logprobs: int):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name
+    }
+
+    if prompt_logprobs is not None:
+        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
+
+    if prompt_logprobs and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError) as err_info:
+            await client.chat.completions.create(**params)
+        expected_err_string = (
+            "Error code: 400 - {'object': 'error', 'message': "
+            "'Prompt_logprobs set to invalid negative value: -1',"
+            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
+        assert str(err_info.value) == expected_err_string
+    else:
+        completion = await client.chat.completions.create(**params)
+        if prompt_logprobs and prompt_logprobs > 0:
+            assert completion.prompt_logprobs is not None
+            assert len(completion.prompt_logprobs) > 0
+        else:
+            assert completion.prompt_logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                                  model_name: str):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name,
+        "extra_body": {
+            "prompt_logprobs": 1
+        }
+    }
+
+    completion_1 = await client.chat.completions.create(**params)
+
+    params["extra_body"] = {"prompt_logprobs": 2}
+    completion_2 = await client.chat.completions.create(**params)
+
+    assert len(completion_1.prompt_logprobs[3]) == 1
+    assert len(completion_2.prompt_logprobs[3]) == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
+                                                         (MODEL_NAME, 0),
+                                                         (MODEL_NAME, 1),
+                                                         (MODEL_NAME, None)])
+async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
+                                          model_name: str,
+                                          prompt_logprobs: int):
+    params: Dict = {
+        "prompt": ["A robot may not injure another robot", "My name is"],
+        "model": model_name,
+    }
+    if prompt_logprobs is not None:
+        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
+
+    if prompt_logprobs and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError) as err_info:
+            await client.completions.create(**params)
+        expected_err_string = (
+            "Error code: 400 - {'object': 'error', 'message': "
+            "'Prompt_logprobs set to invalid negative value: -1',"
+            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
+        assert str(err_info.value) == expected_err_string
+    else:
+        completion = await client.completions.create(**params)
+        if prompt_logprobs and prompt_logprobs > 0:
+            assert completion.choices[0].prompt_logprobs is not None
+            assert len(completion.choices[0].prompt_logprobs) > 0
+
+            assert completion.choices[1].prompt_logprobs is not None
+            assert len(completion.choices[1].prompt_logprobs) > 0
+
+        else:
+            assert completion.choices[0].prompt_logprobs is None
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 7da3002b..aef42e94 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -13,6 +13,7 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import LogitsProcessor, SamplingParams
+from vllm.sequence import Logprob
 from vllm.utils import random_uuid
 
 # torch is mocked during docs generation,
@@ -152,6 +153,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    prompt_logprobs: Optional[int] = None
     # doc: end-chat-completion-sampling-params
 
     # doc: begin-chat-completion-extra-params
@@ -263,7 +265,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             logprobs=self.top_logprobs if self.logprobs else None,
-            prompt_logprobs=self.top_logprobs if self.echo else None,
+            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
+            (self.top_logprobs if self.echo else None),
             ignore_eos=self.ignore_eos,
             max_tokens=max_tokens,
             min_tokens=self.min_tokens,
@@ -368,6 +371,7 @@ class CompletionRequest(OpenAIBaseModel):
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     allowed_token_ids: Optional[List[int]] = None
+    prompt_logprobs: Optional[int] = None
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
@@ -454,7 +458,8 @@ class CompletionRequest(OpenAIBaseModel):
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
-            prompt_logprobs=self.logprobs if self.echo else None,
+            prompt_logprobs=self.prompt_logprobs
+            if self.prompt_logprobs else self.logprobs if self.echo else None,
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
@@ -532,6 +537,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
             "to stop, None if the completion finished for some other reason "
             "including encountering the EOS token"),
     )
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
 
 
 class CompletionResponse(OpenAIBaseModel):
@@ -627,6 +633,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
 
 
 class DeltaMessage(OpenAIBaseModel):
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 2167b967..08209d44 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -83,6 +83,16 @@ class OpenAIServingChat(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret
 
+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+
+            if request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid "
+                    f"negative value: {request.prompt_logprobs}")
+
         try:
             (
                 lora_request,
@@ -506,6 +516,7 @@ class OpenAIServingChat(OpenAIServing):
             model=model_name,
             choices=choices,
             usage=usage,
+            prompt_logprobs=final_res.prompt_logprobs,
         )
 
         return response
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index f4c91ce0..24206b59 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -84,6 +84,15 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
 
+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+            elif request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid negative "
+                    f"value: {request.prompt_logprobs}")
+
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -377,6 +386,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     logprobs=logprobs,
                     finish_reason=output.finish_reason,
                     stop_reason=output.stop_reason,
+                    prompt_logprobs=final_res.prompt_logprobs,
                 )
                 choices.append(choice_data)
 
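
For reference, a minimal client-side sketch of how the new prompt_logprobs extra parameter could be exercised against a running server, mirroring the tests above. The base URL, API key, and model name are illustrative placeholders, not values from this change.

from openai import OpenAI

# vLLM exposes an OpenAI-compatible server; prompt_logprobs is passed via
# extra_body because it is a vLLM-specific extension of the API.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="my-served-model",            # placeholder for the served model name
    prompt="A robot may not injure another robot",
    max_tokens=5,
    temperature=0.0,
    extra_body={"prompt_logprobs": 1},  # number of logprobs per prompt token
)

# One entry per prompt position (the first is None), each mapping
# token id -> Logprob, as carried on CompletionResponseChoice.
print(completion.choices[0].prompt_logprobs)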