[BugFix][Frontend] Fix completion logprobs=0 error (#3731)
parent 6110c39dc8
commit f510395bbf
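For context, the request shape that used to fail is the one the new test below exercises. A minimal client-side reproduction sketch, assuming a locally running vLLM OpenAI-compatible server and an arbitrary served model name (both hypothetical, not taken from this diff):

import openai

# Assumed local endpoint; adjust base_url/api_key for your deployment.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Before this commit, logprobs=0 was treated as falsy ("not requested"),
# which led to the completion error this commit fixes.
completion = client.completions.create(
    model="my-served-model",   # hypothetical model name
    prompt=[0, 0, 0, 0, 0],    # prompt given as token IDs, as in the new test
    max_tokens=5,
    temperature=0.0,
    logprobs=0,                # sampled-token logprobs only, no alternatives
)
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is None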
@@ -199,6 +199,27 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
                        completion.choices[0].text) >= 5


+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
+                             model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=0,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.token_logprobs is not None
+    assert choice.logprobs.top_logprobs is None
+
+
 @pytest.mark.parametrize(
     # just test 1 lora hereafter
     "model_name",
@@ -330,7 +330,7 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
+    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


 class CompletionResponseChoice(BaseModel):
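The hunk above narrows the response schema: in the OpenAI completions API, top_logprobs maps decoded token strings to log probabilities, so str keys are the correct type. A small standalone sketch of the corrected field, assuming pydantic and a trimmed-down version of the class (the real one carries more fields):

from typing import Dict, List, Optional
from pydantic import BaseModel, Field

class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    # Keys are token strings (OpenAI-style), not integer token IDs.
    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None

# A typical per-position entry now validates cleanly:
lp = LogProbs(tokens=["Hello"], token_logprobs=[-0.1],
              top_logprobs=[{"Hello": -0.1, "Hi": -2.3}])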
@@ -251,9 +251,6 @@ class OpenAIServingCompletion(OpenAIServing):
                         i]:] if output.logprobs else None

                 if request.logprobs is not None:
-                    assert top_logprobs is not None, (
-                        "top_logprobs must be provided when logprobs "
-                        "is requested")
                     logprobs = self._create_logprobs(
                         token_ids=delta_token_ids,
                         top_logprobs=top_logprobs,
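This assertion appears to be where a logprobs=0 request actually surfaced as an error: with the pre-fix truthiness checks elsewhere, the output's logprobs came back empty for a zero value, top_logprobs ended up None, and the assert raised even though the request was valid. A hedged, self-contained sketch of that interaction (variable names are illustrative, not the serving code's):

request_logprobs = 0       # client sent logprobs=0
output_logprobs = None     # dropped upstream because 0 was treated as falsy

top_logprobs = output_logprobs if output_logprobs else None

try:
    if request_logprobs is not None:          # 0 is not None, so we enter...
        assert top_logprobs is not None, (
            "top_logprobs must be provided when logprobs "
            "is requested")
except AssertionError as exc:
    print(f"logprobs=0 request failed: {exc}")  # the error this commit removes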
@@ -534,7 +534,8 @@ def _get_logprobs(
     # Prepare query indices
     batched_logprobs_query_seq_indices: List[int] = []
     batched_logprobs_query_token_indices: List[int] = []
-    largest_num_logprobs = 0
+    # at least get one logprob for each token
+    largest_num_logprobs = 1
     sample_idx = 0
     for i, (seq_group, sample_result) in enumerate(
             zip(sampling_metadata.seq_groups, sample_results)):
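The floor of 1 here means the sampler always gathers at least the sampled token's own logprob, even when every request in the batch asked for zero alternatives. A hedged, simplified illustration of the sizing logic (not the actual _get_logprobs code):

def widest_query(requested_logprobs):
    # at least get one logprob for each token, as in the hunk above
    largest_num_logprobs = 1
    for num_logprobs in requested_logprobs:
        largest_num_logprobs = max(largest_num_logprobs, num_logprobs)
    return largest_num_logprobs

# Per-sequence `logprobs` values in a batch (0 means "sampled token only").
assert widest_query([0, 0, 5]) == 5
assert widest_query([0, 0, 0]) == 1   # previously this would have been 0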
@@ -643,7 +644,7 @@ def _get_logprobs(
                 batched_ranks_query_result[query_result_idx].item())
         }
         query_result_idx += 1
-        if num_logprobs > 0:
+        if num_logprobs >= 0:
             sample_logprobs_dict.update(
                 zip(
                     top_token_ids[sample_idx +
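With the floor above in place, zero is now a meaningful request rather than a sentinel for "nothing", so the comparison loosens from > 0 to >= 0. The resulting request semantics, roughly (a hedged summary, not quoted from the code):

def logprobs_meaning(logprobs):
    # Hedged summary of the completion `logprobs` semantics after this commit.
    if logprobs is None:
        return "no logprobs in the output"
    if logprobs == 0:
        return "logprob of each sampled token, no alternatives"
    return f"sampled token plus top-{logprobs} alternatives per position"

assert logprobs_meaning(None) == "no logprobs in the output"
assert logprobs_meaning(0) == "logprob of each sampled token, no alternatives"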
@@ -111,7 +111,7 @@ class RequestOutput:
         # NOTE: We need omit logprobs here explicitly because the sequence
         # always has the logprobs of the sampled tokens even if the
         # logprobs are not requested.
-        include_logprobs = seq_group.sampling_params.logprobs
+        include_logprobs = seq_group.sampling_params.logprobs is not None
         outputs = [
             CompletionOutput(seqs.index(seq), seq.output_text,
                              seq.get_output_token_ids(),
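This last hunk addresses the underlying truthiness bug: the old check made logprobs=0 indistinguishable from logprobs=None, so the sampled-token logprobs were stripped from the RequestOutput. A tiny sketch of the difference (illustrative values only):

sampling_params_logprobs = 0                         # user asked for 0 alternatives

include_old = bool(sampling_params_logprobs)         # old check: 0 is falsy, logprobs dropped
include_new = sampling_params_logprobs is not None   # new check: 0 still means "include them"

assert include_old is False
assert include_new is True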