[BugFix][Frontend] Fix completion logprobs=0 error (#3731)

This commit is contained in:
Roy 2024-03-30 00:38:21 +08:00 committed by GitHub
parent 6110c39dc8
commit f510395bbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 26 additions and 7 deletions

View File

@ -199,6 +199,27 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
completion.choices[0].text) >= 5
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=0,
)
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is None
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",

View File

@ -330,7 +330,7 @@ class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
class CompletionResponseChoice(BaseModel):

View File

@ -251,9 +251,6 @@ class OpenAIServingCompletion(OpenAIServing):
i]:] if output.logprobs else None
if request.logprobs is not None:
assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs "
"is requested")
logprobs = self._create_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,

View File

@ -534,7 +534,8 @@ def _get_logprobs(
# Prepare query indices
batched_logprobs_query_seq_indices: List[int] = []
batched_logprobs_query_token_indices: List[int] = []
largest_num_logprobs = 0
# at least get one logprob for each token
largest_num_logprobs = 1
sample_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(sampling_metadata.seq_groups, sample_results)):
@ -643,7 +644,7 @@ def _get_logprobs(
batched_ranks_query_result[query_result_idx].item())
}
query_result_idx += 1
if num_logprobs > 0:
if num_logprobs >= 0:
sample_logprobs_dict.update(
zip(
top_token_ids[sample_idx +

View File

@ -111,7 +111,7 @@ class RequestOutput:
# NOTE: We need omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs
include_logprobs = seq_group.sampling_params.logprobs is not None
outputs = [
CompletionOutput(seqs.index(seq), seq.output_text,
seq.get_output_token_ids(),