[Feature]: Add OpenAI server prompt_logprobs support #6508 (#7453)

Grant Pinkert 2024-08-16 12:38:08 +10:00 committed by GitHub
parent b67ae00cdb
commit f878c8feb0
4 changed files with 154 additions and 3 deletions
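
For orientation, here is a minimal client-side sketch of how the new parameter is exercised once this commit is applied, mirroring the tests below. The base URL, API key, and model name are placeholders, and because prompt_logprobs is a vLLM-specific extension, the official OpenAI client has to pass it through extra_body.

from openai import OpenAI

# Placeholder endpoint and credentials; assumes a vLLM OpenAI-compatible server is running.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="your-served-model",          # placeholder model name
    prompt="A robot may not injure another robot",
    max_tokens=5,
    extra_body={"prompt_logprobs": 1},  # vLLM-specific field added by this commit
)

# Per the new CompletionResponseChoice field, each choice now carries logprobs
# for its own prompt tokens (None when the feature is not requested).
print(completion.choices[0].prompt_logprobs)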


@@ -3,7 +3,7 @@ import json
import re
import shutil
from tempfile import TemporaryDirectory
from typing import Dict, List

import jsonschema
import openai  # use the official client for correctness check
@@ -130,6 +130,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
        temperature=0.0,
    )
    assert len(completion.choices[0].text) >= 1
    assert completion.choices[0].prompt_logprobs is None


@pytest.mark.asyncio
@@ -267,6 +268,128 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
    assert len(completion.choices[0].text) >= 0

@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, prompt_logprobs",
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str, prompt_logprobs: int):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name
}
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
await client.chat.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.chat.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
assert completion.prompt_logprobs is not None
assert len(completion.prompt_logprobs) > 0
else:
assert completion.prompt_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name,
"extra_body": {
"prompt_logprobs": 1
}
}
completion_1 = await client.chat.completions.create(**params)
params["extra_body"] = {"prompt_logprobs": 2}
completion_2 = await client.chat.completions.create(**params)
assert len(completion_1.prompt_logprobs[3]) == 1
assert len(completion_2.prompt_logprobs[3]) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
(MODEL_NAME, 0),
(MODEL_NAME, 1),
(MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: int):
params: Dict = {
"prompt": ["A robot may not injure another robot", "My name is"],
"model": model_name,
}
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
await client.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
assert completion.choices[0].prompt_logprobs is not None
assert len(completion.choices[0].prompt_logprobs) > 0
assert completion.choices[1].prompt_logprobs is not None
assert len(completion.choices[1].prompt_logprobs) > 0
else:
assert completion.choices[0].prompt_logprobs is None

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",


@@ -13,6 +13,7 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.utils import random_uuid

# torch is mocked during docs generation,
@@ -152,6 +153,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
@@ -263,7 +265,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
@@ -368,6 +371,7 @@ class CompletionRequest(OpenAIBaseModel):
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
@@ -454,7 +458,8 @@ class CompletionRequest(OpenAIBaseModel):
            min_tokens=self.min_tokens,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
@@ -532,6 +537,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
        "to stop, None if the completion finished for some other reason "
        "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
@@ -627,6 +633,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
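
Given the response-model additions above, the following is a hedged sketch of how a client might walk the returned structure, continuing the completion example near the top. The List[Optional[Dict[int, Logprob]]] shape comes from the diff; how the official client deserializes the extra field (plain dicts vs. typed objects) is an assumption.

# Chat responses expose prompt_logprobs at the top level of the response, while
# completion responses expose it per choice. Each list entry maps token_id to
# logprob info for one prompt position; the first entry is None because the
# first prompt token has no preceding context.
prompt_logprobs = completion.choices[0].prompt_logprobs or []
for position, token_logprobs in enumerate(prompt_logprobs):
    if token_logprobs is None:
        continue
    for token_id, info in token_logprobs.items():
        # `info` likely arrives as untyped JSON (e.g. keys such as "logprob",
        # "rank", "decoded_token", assumed from vllm.sequence.Logprob).
        print(position, token_id, info)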


@@ -83,6 +83,16 @@ class OpenAIServingChat(OpenAIServing):
        if error_check_ret is not None:
            return error_check_ret

if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")
if request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid "
f"negative value: {request.prompt_logprobs}")
        try:
            (
                lora_request,
@@ -506,6 +516,7 @@ class OpenAIServingChat(OpenAIServing):
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response


@@ -84,6 +84,15 @@ class OpenAIServingCompletion(OpenAIServing):
        request_id = f"cmpl-{random_uuid()}"
        created_time = int(time.time())

if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")
elif request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid negative "
f"value: {request.prompt_logprobs}")
        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
@@ -377,6 +386,7 @@ class OpenAIServingCompletion(OpenAIServing):
                logprobs=logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason,
                prompt_logprobs=final_res.prompt_logprobs,
            )
            choices.append(choice_data)