parent b67ae00cdb
commit f878c8feb0
@@ -3,7 +3,7 @@ import json
 import re
 import shutil
 from tempfile import TemporaryDirectory
-from typing import List
+from typing import Dict, List

 import jsonschema
 import openai  # use the official client for correctness check
@@ -130,6 +130,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
         temperature=0.0,
     )
     assert len(completion.choices[0].text) >= 1
+    assert completion.choices[0].prompt_logprobs is None


 @pytest.mark.asyncio
@@ -267,6 +268,128 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
     assert len(completion.choices[0].text) >= 0


+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name, prompt_logprobs",
+    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
+)
+async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                    model_name: str, prompt_logprobs: int):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name
+    }
+
+    if prompt_logprobs is not None:
+        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
+
+    if prompt_logprobs and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError) as err_info:
+            await client.chat.completions.create(**params)
+        expected_err_string = (
+            "Error code: 400 - {'object': 'error', 'message': "
+            "'Prompt_logprobs set to invalid negative value: -1',"
+            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
+        assert str(err_info.value) == expected_err_string
+    else:
+        completion = await client.chat.completions.create(**params)
+        if prompt_logprobs and prompt_logprobs > 0:
+            assert completion.prompt_logprobs is not None
+            assert len(completion.prompt_logprobs) > 0
+        else:
+            assert completion.prompt_logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                                  model_name: str):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name,
+        "extra_body": {
+            "prompt_logprobs": 1
+        }
+    }
+
+    completion_1 = await client.chat.completions.create(**params)
+
+    params["extra_body"] = {"prompt_logprobs": 2}
+    completion_2 = await client.chat.completions.create(**params)
+
+    assert len(completion_1.prompt_logprobs[3]) == 1
+    assert len(completion_2.prompt_logprobs[3]) == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
+                                                         (MODEL_NAME, 0),
+                                                         (MODEL_NAME, 1),
+                                                         (MODEL_NAME, None)])
+async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
+                                          model_name: str,
+                                          prompt_logprobs: int):
+    params: Dict = {
+        "prompt": ["A robot may not injure another robot", "My name is"],
+        "model": model_name,
+    }
+    if prompt_logprobs is not None:
+        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
+
+    if prompt_logprobs and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError) as err_info:
+            await client.completions.create(**params)
+        expected_err_string = (
+            "Error code: 400 - {'object': 'error', 'message': "
+            "'Prompt_logprobs set to invalid negative value: -1',"
+            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
+        assert str(err_info.value) == expected_err_string
+    else:
+        completion = await client.completions.create(**params)
+        if prompt_logprobs and prompt_logprobs > 0:
+            assert completion.choices[0].prompt_logprobs is not None
+            assert len(completion.choices[0].prompt_logprobs) > 0
+
+            assert completion.choices[1].prompt_logprobs is not None
+            assert len(completion.choices[1].prompt_logprobs) > 0
+
+        else:
+            assert completion.choices[0].prompt_logprobs is None
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
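For readers trying the new parameter outside the test suite: it is reached through the official client's extra_body passthrough, exactly as the tests above do. A minimal sketch, assuming a vLLM OpenAI-compatible server already running at http://localhost:8000/v1; the server URL and model name below are placeholders, not part of this change.

# Minimal usage sketch; URL and model name are placeholders.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="MODEL_NAME",  # placeholder for the served model
    prompt="A robot may not injure another robot",
    max_tokens=5,
    extra_body={"prompt_logprobs": 1},  # vLLM-specific field added by this change
)
# One entry per prompt token; the first entry is None.
print(completion.choices[0].prompt_logprobs)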
@@ -13,6 +13,7 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import LogitsProcessor, SamplingParams
+from vllm.sequence import Logprob
 from vllm.utils import random_uuid

 # torch is mocked during docs generation,
@@ -152,6 +153,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    prompt_logprobs: Optional[int] = None
     # doc: end-chat-completion-sampling-params

     # doc: begin-chat-completion-extra-params
@@ -263,7 +265,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             logprobs=self.top_logprobs if self.logprobs else None,
-            prompt_logprobs=self.top_logprobs if self.echo else None,
+            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
+            (self.top_logprobs if self.echo else None),
             ignore_eos=self.ignore_eos,
             max_tokens=max_tokens,
             min_tokens=self.min_tokens,
@@ -368,6 +371,7 @@ class CompletionRequest(OpenAIBaseModel):
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     allowed_token_ids: Optional[List[int]] = None
+    prompt_logprobs: Optional[int] = None
     # doc: end-completion-sampling-params

     # doc: begin-completion-extra-params
@@ -454,7 +458,8 @@ class CompletionRequest(OpenAIBaseModel):
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
-            prompt_logprobs=self.logprobs if self.echo else None,
+            prompt_logprobs=self.prompt_logprobs
+            if self.prompt_logprobs else self.logprobs if self.echo else None,
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
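In both request classes the explicit field takes precedence over the older echo-based behaviour: a truthy prompt_logprobs wins, otherwise echo falls back to the existing logprobs (top_logprobs for chat), otherwise prompt logprobs stay off. Because the check is truthiness rather than `is not None`, an explicit 0 falls through to the echo path. A standalone sketch of that resolution order (the helper name is illustrative only, not part of the diff):

from typing import Optional


def resolve_prompt_logprobs(prompt_logprobs: Optional[int], echo: bool,
                            logprobs: Optional[int]) -> Optional[int]:
    # Illustrative helper mirroring the precedence used in to_sampling_params above.
    if prompt_logprobs:  # truthy explicit value wins; 0 falls through, as in the diff
        return prompt_logprobs
    return logprobs if echo else None


assert resolve_prompt_logprobs(2, echo=True, logprobs=5) == 2
assert resolve_prompt_logprobs(0, echo=True, logprobs=5) == 5
assert resolve_prompt_logprobs(None, echo=False, logprobs=5) is None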
@@ -532,6 +537,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
             "to stop, None if the completion finished for some other reason "
             "including encountering the EOS token"),
     )
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


 class CompletionResponse(OpenAIBaseModel):
@@ -627,6 +633,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


 class DeltaMessage(OpenAIBaseModel):
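The response-side field has the same shape in CompletionResponseChoice and ChatCompletionResponse: one entry per prompt token, where the first entry is None and each later entry maps candidate token ids to a Logprob (logprob, rank, decoded_token). A hypothetical decoded payload, with made-up ids and values, might look like this:

# Hypothetical shape for a 3-token prompt requested with prompt_logprobs=1;
# token ids, strings and numbers are illustrative only.
example_prompt_logprobs = [
    None,  # no logprob is produced for the first prompt token
    {382: {"logprob": -2.31, "rank": 1, "decoded_token": " robot"}},
    {287: {"logprob": -0.87, "rank": 1, "decoded_token": " may"}},
]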
@@ -83,6 +83,16 @@ class OpenAIServingChat(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret

+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+
+            if request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid "
+                    f"negative value: {request.prompt_logprobs}")
+
         try:
             (
                 lora_request,
@@ -506,6 +516,7 @@ class OpenAIServingChat(OpenAIServing):
             model=model_name,
             choices=choices,
             usage=usage,
+            prompt_logprobs=final_res.prompt_logprobs,
         )

         return response
@@ -84,6 +84,15 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())

+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+            elif request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid negative "
+                    f"value: {request.prompt_logprobs}")
+
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
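Both serving frontends reject prompt_logprobs before the request is scheduled: streaming with a positive value and any negative value each return an error, which the client sees as HTTP 400 (this is what the new tests assert). A minimal client-side sketch of the rejection path, with placeholder server address and model name:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
try:
    client.completions.create(
        model="MODEL_NAME",  # placeholder
        prompt="My name is",
        extra_body={"prompt_logprobs": -1},  # rejected: negative value
    )
except openai.BadRequestError as err:
    # Expected message: "Prompt_logprobs set to invalid negative value: -1"
    print(err)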
@@ -377,6 +386,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     logprobs=logprobs,
                     finish_reason=output.finish_reason,
                     stop_reason=output.stop_reason,
+                    prompt_logprobs=final_res.prompt_logprobs,
                 )
                 choices.append(choice_data)
