[BugFix][Frontend] Fix completion logprobs=0 error (#3731)
This commit is contained in:
parent
6110c39dc8
commit
f510395bbf
@ -199,6 +199,27 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
|
|||||||
completion.choices[0].text) >= 5
|
completion.choices[0].text) >= 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
                             model_name: str):
    """Check that a completion request with ``logprobs=0`` succeeds.

    Sends a token-ID prompt with ``logprobs=0`` to each parametrized model
    and verifies that the response carries a logprobs object whose
    ``token_logprobs`` is populated while ``top_logprobs`` is ``None``.

    Args:
        server: Server fixture; unused directly, requested so the server is
            running before the request is made.
        client: Async OpenAI client used to issue the completion request.
        model_name: Name of the model (base or LoRA) to query.
    """
    # test using token IDs
    completion = await client.completions.create(
        # Use the parametrized name; a hard-coded MODEL_NAME here would make
        # every parametrized case hit the base model only.
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=0,
    )
    choice = completion.choices[0]
    assert choice.logprobs is not None
    assert choice.logprobs.token_logprobs is not None
    assert choice.logprobs.top_logprobs is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
# just test 1 lora hereafter
|
# just test 1 lora hereafter
|
||||||
"model_name",
|
"model_name",
|
||||||
|
|||||||
@ -330,7 +330,7 @@ class LogProbs(BaseModel):
|
|||||||
text_offset: List[int] = Field(default_factory=list)
|
text_offset: List[int] = Field(default_factory=list)
|
||||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||||
tokens: List[str] = Field(default_factory=list)
|
tokens: List[str] = Field(default_factory=list)
|
||||||
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
|
top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
|
||||||
|
|
||||||
|
|
||||||
class CompletionResponseChoice(BaseModel):
|
class CompletionResponseChoice(BaseModel):
|
||||||
|
|||||||
@ -251,9 +251,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
i]:] if output.logprobs else None
|
i]:] if output.logprobs else None
|
||||||
|
|
||||||
if request.logprobs is not None:
|
if request.logprobs is not None:
|
||||||
assert top_logprobs is not None, (
|
|
||||||
"top_logprobs must be provided when logprobs "
|
|
||||||
"is requested")
|
|
||||||
logprobs = self._create_logprobs(
|
logprobs = self._create_logprobs(
|
||||||
token_ids=delta_token_ids,
|
token_ids=delta_token_ids,
|
||||||
top_logprobs=top_logprobs,
|
top_logprobs=top_logprobs,
|
||||||
|
|||||||
@ -534,7 +534,8 @@ def _get_logprobs(
|
|||||||
# Prepare query indices
|
# Prepare query indices
|
||||||
batched_logprobs_query_seq_indices: List[int] = []
|
batched_logprobs_query_seq_indices: List[int] = []
|
||||||
batched_logprobs_query_token_indices: List[int] = []
|
batched_logprobs_query_token_indices: List[int] = []
|
||||||
largest_num_logprobs = 0
|
# at least get one logprob for each token
|
||||||
|
largest_num_logprobs = 1
|
||||||
sample_idx = 0
|
sample_idx = 0
|
||||||
for i, (seq_group, sample_result) in enumerate(
|
for i, (seq_group, sample_result) in enumerate(
|
||||||
zip(sampling_metadata.seq_groups, sample_results)):
|
zip(sampling_metadata.seq_groups, sample_results)):
|
||||||
@ -643,7 +644,7 @@ def _get_logprobs(
|
|||||||
batched_ranks_query_result[query_result_idx].item())
|
batched_ranks_query_result[query_result_idx].item())
|
||||||
}
|
}
|
||||||
query_result_idx += 1
|
query_result_idx += 1
|
||||||
if num_logprobs > 0:
|
if num_logprobs >= 0:
|
||||||
sample_logprobs_dict.update(
|
sample_logprobs_dict.update(
|
||||||
zip(
|
zip(
|
||||||
top_token_ids[sample_idx +
|
top_token_ids[sample_idx +
|
||||||
|
|||||||
@ -111,7 +111,7 @@ class RequestOutput:
|
|||||||
# NOTE: We need to omit logprobs here explicitly because the sequence
|
# NOTE: We need to omit logprobs here explicitly because the sequence
|
||||||
# always has the logprobs of the sampled tokens even if the
|
# always has the logprobs of the sampled tokens even if the
|
||||||
# logprobs are not requested.
|
# logprobs are not requested.
|
||||||
include_logprobs = seq_group.sampling_params.logprobs
|
include_logprobs = seq_group.sampling_params.logprobs is not None
|
||||||
outputs = [
|
outputs = [
|
||||||
CompletionOutput(seqs.index(seq), seq.output_text,
|
CompletionOutput(seqs.index(seq), seq.output_text,
|
||||||
seq.get_output_token_ids(),
|
seq.get_output_token_ids(),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user