parent b67ae00cdb
commit f878c8feb0
@@ -3,7 +3,7 @@ import json
 import re
 import shutil
 from tempfile import TemporaryDirectory
-from typing import List
+from typing import Dict, List
 
 import jsonschema
 import openai  # use the official client for correctness check
@@ -130,6 +130,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
         temperature=0.0,
     )
     assert len(completion.choices[0].text) >= 1
+    assert completion.choices[0].prompt_logprobs is None
 
 
 @pytest.mark.asyncio
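This one-line assertion pins down the default contract: a request that never mentions `prompt_logprobs` must come back with the field set to `None`, not an empty list. The new tests added in the next hunk rely on the same contract for their `prompt_logprobs=None` cases.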
@@ -267,6 +268,128 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
     assert len(completion.choices[0].text) >= 0
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name, prompt_logprobs",
+    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
+)
+async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                    model_name: str, prompt_logprobs: int):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name
+    }
+
+    if prompt_logprobs is not None:
+        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
+
+    if prompt_logprobs and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError) as err_info:
+            await client.chat.completions.create(**params)
+        expected_err_string = (
+            "Error code: 400 - {'object': 'error', 'message': "
+            "'Prompt_logprobs set to invalid negative value: -1',"
+            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
+        assert str(err_info.value) == expected_err_string
+    else:
+        completion = await client.chat.completions.create(**params)
+        if prompt_logprobs and prompt_logprobs > 0:
+            assert completion.prompt_logprobs is not None
+            assert len(completion.prompt_logprobs) > 0
+        else:
+            assert completion.prompt_logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                                  model_name: str):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name,
+        "extra_body": {
+            "prompt_logprobs": 1
+        }
+    }
+
+    completion_1 = await client.chat.completions.create(**params)
+
+    params["extra_body"] = {"prompt_logprobs": 2}
+    completion_2 = await client.chat.completions.create(**params)
+
+    assert len(completion_1.prompt_logprobs[3]) == 1
+    assert len(completion_2.prompt_logprobs[3]) == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
+                                                         (MODEL_NAME, 0),
+                                                         (MODEL_NAME, 1),
+                                                         (MODEL_NAME, None)])
+async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
+                                          model_name: str,
+                                          prompt_logprobs: int):
+    params: Dict = {
+        "prompt": ["A robot may not injure another robot", "My name is"],
+        "model": model_name,
+    }
+    if prompt_logprobs is not None:
+        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
+
+    if prompt_logprobs and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError) as err_info:
+            await client.completions.create(**params)
+        expected_err_string = (
+            "Error code: 400 - {'object': 'error', 'message': "
+            "'Prompt_logprobs set to invalid negative value: -1',"
+            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
+        assert str(err_info.value) == expected_err_string
+    else:
+        completion = await client.completions.create(**params)
+        if prompt_logprobs and prompt_logprobs > 0:
+            assert completion.choices[0].prompt_logprobs is not None
+            assert len(completion.choices[0].prompt_logprobs) > 0
+
+            assert completion.choices[1].prompt_logprobs is not None
+            assert len(completion.choices[1].prompt_logprobs) > 0
+
+        else:
+            assert completion.choices[0].prompt_logprobs is None
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
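A convention worth calling out in the new tests: the official `openai` client does not know about `prompt_logprobs`, so it travels in `extra_body`, which the client merges into the top level of the JSON request body. A minimal standalone sketch of the same pattern; the server URL, API key, and model name are illustrative placeholders, not taken from this commit:

```python
import asyncio

import openai  # official client; extra_body carries vLLM-specific params


async def main() -> None:
    client = openai.AsyncOpenAI(
        base_url="http://localhost:8000/v1",  # assumed local vLLM server
        api_key="EMPTY",  # vLLM does not validate the key by default
    )
    completion = await client.completions.create(
        model="facebook/opt-125m",  # illustrative model name
        prompt="A robot may not injure a human being",
        max_tokens=5,
        extra_body={"prompt_logprobs": 1},  # vLLM extension, not OpenAI API
    )
    # One list entry per prompt token; the first entry is None because
    # the first token has nothing to condition on (assumed vLLM behavior).
    print(completion.choices[0].prompt_logprobs)


asyncio.run(main())
```

Note also the guard style `if prompt_logprobs and prompt_logprobs < 0`: it deliberately treats `0` like `None`, mirroring the truthiness fallback in the protocol hunks below.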
@@ -13,6 +13,7 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import LogitsProcessor, SamplingParams
+from vllm.sequence import Logprob
 from vllm.utils import random_uuid
 
 # torch is mocked during docs generation,
@@ -152,6 +153,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    prompt_logprobs: Optional[int] = None
     # doc: end-chat-completion-sampling-params
 
     # doc: begin-chat-completion-extra-params
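Placement note: the field is declared just above the `# doc: end-chat-completion-sampling-params` marker, so the generated parameter docs should list it with the sampling parameters (next to `truncate_prompt_tokens`) rather than with the extra params.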
@@ -263,7 +265,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             logprobs=self.top_logprobs if self.logprobs else None,
-            prompt_logprobs=self.top_logprobs if self.echo else None,
+            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
+            (self.top_logprobs if self.echo else None),
             ignore_eos=self.ignore_eos,
             max_tokens=max_tokens,
             min_tokens=self.min_tokens,
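This replaced line is where the new request field meets the pre-existing `echo` behavior: an explicit `prompt_logprobs` wins, otherwise `echo` falls back to `top_logprobs`, otherwise `None`. Because the check is truthiness rather than `is not None`, an explicit `prompt_logprobs=0` also falls through to the `echo` fallback. A standalone sketch of the same resolution (the function name is mine, for illustration only):

```python
from typing import Optional


def resolve_prompt_logprobs(prompt_logprobs: Optional[int],
                            top_logprobs: Optional[int],
                            echo: bool) -> Optional[int]:
    # Mirrors the expression in the hunk above: a truthiness test,
    # so an explicit 0 behaves like None and defers to echo.
    return prompt_logprobs if prompt_logprobs else (top_logprobs
                                                    if echo else None)


assert resolve_prompt_logprobs(2, None, echo=False) == 2
assert resolve_prompt_logprobs(None, 5, echo=True) == 5
assert resolve_prompt_logprobs(0, 5, echo=True) == 5  # 0 is falsy here
assert resolve_prompt_logprobs(None, 5, echo=False) is None
```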
@@ -368,6 +371,7 @@ class CompletionRequest(OpenAIBaseModel):
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     allowed_token_ids: Optional[List[int]] = None
+    prompt_logprobs: Optional[int] = None
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
@@ -454,7 +458,8 @@ class CompletionRequest(OpenAIBaseModel):
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
-            prompt_logprobs=self.logprobs if self.echo else None,
+            prompt_logprobs=self.prompt_logprobs
+            if self.prompt_logprobs else self.logprobs if self.echo else None,
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
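The completion-side chain mirrors the chat path above, with the completion API's `logprobs` standing in for `top_logprobs`. The same truthiness caveat applies: an explicit `prompt_logprobs=0` defers to the `echo` fallback rather than disabling prompt logprobs outright.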
@@ -532,6 +537,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
         "to stop, None if the completion finished for some other reason "
         "including encountering the EOS token"),
     )
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
 
 
 class CompletionResponse(OpenAIBaseModel):
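The nested type documents the wire shape: one list entry per prompt token, `None` where no distribution is available (in practice the first token, stated here as an assumption), and otherwise a mapping from candidate token id to its `Logprob`. A sketch of consuming that structure from the decoded JSON on the client side; the `logprob`/`rank`/`decoded_token` field names come from vLLM's `Logprob` dataclass, and the concrete ids and values below are invented for illustration. Note that JSON object keys arrive as strings, hence `str` keys here even though the server-side model declares `Dict[int, Logprob]`:

```python
from typing import Dict, List, Optional

# Invented example payload in the shape of the new response field.
prompt_logprobs: List[Optional[Dict[str, dict]]] = [
    None,  # first prompt token: no context to condition on
    {"464": {"logprob": -2.31, "rank": 4, "decoded_token": "The"}},
    {"3290": {"logprob": -0.07, "rank": 1, "decoded_token": " robot"}},
]

for position, candidates in enumerate(prompt_logprobs):
    if candidates is None:
        continue
    for token_id, info in candidates.items():
        print(f"pos={position} id={token_id} "
              f"logprob={info['logprob']:.2f} "
              f"token={info['decoded_token']!r}")
```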
@@ -627,6 +633,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
 
 
 class DeltaMessage(OpenAIBaseModel):
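Note the asymmetry between the two response models: `ChatCompletionResponse` carries `prompt_logprobs` at the top level, since a chat request has exactly one prompt, while `CompletionResponseChoice` attaches it per choice, since a completion request may batch several prompts. The tests above rely on this when they check both `choices[0]` and `choices[1]`.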
@@ -83,6 +83,16 @@ class OpenAIServingChat(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret
 
+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+
+            if request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid "
+                    f"negative value: {request.prompt_logprobs}")
+
         try:
             (
                 lora_request,
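Both guards surface to the client as HTTP 400 via `create_error_response`, which is exactly what the `BadRequestError` assertions in the tests pin down. (The chat handler uses two separate `if`s where the completion handler below uses `elif`; the behavior is the same since each branch returns.) A hedged sketch of tripping the streaming guard; the server URL and model name are again illustrative:

```python
import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")  # assumed local server
    try:
        await client.chat.completions.create(
            model="facebook/opt-125m",  # illustrative model name
            messages=[{"role": "user", "content": "Hello"}],
            stream=True,
            extra_body={"prompt_logprobs": 1},  # rejected while streaming
        )
    except openai.BadRequestError as exc:
        # Server answers 400: prompt logprobs unavailable with streaming.
        print(exc)


asyncio.run(main())
```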
@@ -506,6 +516,7 @@ class OpenAIServingChat(OpenAIServing):
             model=model_name,
             choices=choices,
             usage=usage,
+            prompt_logprobs=final_res.prompt_logprobs,
         )
 
         return response
@@ -84,6 +84,15 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
 
+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+            elif request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid negative "
+                    f"value: {request.prompt_logprobs}")
+
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -377,6 +386,7 @@ class OpenAIServingCompletion(OpenAIServing):
             logprobs=logprobs,
             finish_reason=output.finish_reason,
             stop_reason=output.stop_reason,
+            prompt_logprobs=final_res.prompt_logprobs,
         )
         choices.append(choice_data)
 