Fix echo/logprob OpenAI completion bug (#3441)
Co-authored-by: Dylan Hawk <dylanwawk@gmail.com>
commit 95e7d4a97c
parent 559eb852f8
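For context, here is a minimal client-side sketch of the request pattern this commit fixes: echoing the prompt back while asking for logprobs on the completions endpoint, the same combination exercised by the new test below. The base URL, API key, and model name are placeholders for a locally running OpenAI-compatible server, not values taken from this commit.

import openai

# Placeholder endpoint/credentials for a locally running OpenAI-compatible server.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# echo=True returns the prompt as part of the completion text; logprobs=1 asks
# for per-token logprobs, including entries for the echoed prompt tokens.
completion = client.completions.create(
    model="my-model",  # placeholder model name
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    echo=True,
    logprobs=1,
)

choice = completion.choices[0]
print(choice.text)                        # begins with the echoed prompt
print(choice.logprobs.token_logprobs[0])  # None: the first prompt token has no logprob

Token-ID prompts (e.g. prompt=[0, 0, 0, 0, 0]) go through the same path; the serving changes below detokenize them so the echoed text and logprobs stay aligned.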
@@ -742,5 +742,36 @@ number: "1" | "2"
     assert content.strip() == ground_truth
 
 
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
+                                       model_name: str):
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+    # test using text and token IDs
+    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
+        completion = await client.completions.create(model=model_name,
+                                                      prompt=prompt,
+                                                      max_tokens=5,
+                                                      temperature=0.0,
+                                                      echo=True,
+                                                      logprobs=1)
+
+        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
+                                                             list) else prompt
+        assert (completion.choices[0].text is not None
+                and re.search(r"^" + prompt_text, completion.choices[0].text))
+        logprobs = completion.choices[0].logprobs
+        assert logprobs is not None
+        assert len(logprobs.text_offset) > 5
+        assert (len(logprobs.token_logprobs) > 5
+                and logprobs.token_logprobs[0] is None)
+        assert (len(logprobs.top_logprobs) > 5
+                and logprobs.top_logprobs[0] is None)
+        assert len(logprobs.tokens) > 5
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
@@ -63,8 +63,9 @@ class OpenAIServingChat(OpenAIServing):
 
         request_id = f"cmpl-{random_uuid()}"
         try:
-            token_ids = self._validate_prompt_and_tokenize(request,
-                                                           prompt=prompt)
+            # Tokenize/detokenize depending on prompt format (string/token list)
+            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+                request, prompt=prompt)
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
             guided_decode_logits_processor = (
@@ -78,8 +79,8 @@ class OpenAIServingChat(OpenAIServing):
         except ValueError as e:
             return self.create_error_response(str(e))
 
-        result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids,
+        result_generator = self.engine.generate(prompt_text, sampling_params,
+                                                request_id, prompt_ids,
                                                 lora_request)
         # Streaming response
         if request.stream:
@@ -136,23 +136,24 @@ class OpenAIServingCompletion(OpenAIServing):
 
             for i, prompt in enumerate(prompts):
                 if prompt_is_tokens:
-                    input_ids = self._validate_prompt_and_tokenize(
+                    prompt_formats = self._validate_prompt_and_tokenize(
                         request,
                         prompt_ids=prompt,
                         truncate_prompt_tokens=sampling_params.
                         truncate_prompt_tokens)
                 else:
-                    input_ids = self._validate_prompt_and_tokenize(
+                    prompt_formats = self._validate_prompt_and_tokenize(
                         request,
                         prompt=prompt,
                         truncate_prompt_tokens=sampling_params.
                         truncate_prompt_tokens)
+                prompt_ids, prompt_text = prompt_formats
 
                 generators.append(
-                    self.engine.generate(prompt,
+                    self.engine.generate(prompt_text,
                                          sampling_params,
                                          f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids,
+                                         prompt_token_ids=prompt_ids,
                                          lora_request=lora_request))
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
@@ -326,7 +327,8 @@ class OpenAIServingCompletion(OpenAIServing):
                 output_text = prompt_text
             elif request.echo and request.max_tokens > 0:
                 token_ids = prompt_token_ids + output.token_ids
-                top_logprobs = prompt_logprobs + output.logprobs
+                top_logprobs = (prompt_logprobs + output.logprobs
+                                if request.logprobs else None)
                 output_text = prompt_text + output.text
             else:
                 token_ids = output.token_ids
@@ -334,6 +336,9 @@ class OpenAIServingCompletion(OpenAIServing):
                 output_text = output.text
 
             if request.logprobs is not None:
+                assert top_logprobs is not None, (
+                    "top_logprobs must be provided when logprobs "
+                    "is requested")
                 logprobs = self._create_logprobs(
                     token_ids=token_ids,
                     top_logprobs=top_logprobs,
@@ -2,7 +2,7 @@ import asyncio
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 from pydantic import conint
 
@@ -99,27 +99,32 @@ class OpenAIServing:
         last_token_len = 0
         if num_output_top_logprobs:
             logprobs.top_logprobs = []
+
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
-            if step_top_logprobs is not None:
-                token_logprob = step_top_logprobs[token_id].logprob
+            if step_top_logprobs is None:
+                token = self.tokenizer.decode(token_id)
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(None)
+                logprobs.top_logprobs.append(None)
             else:
-                token_logprob = None
-            token = step_top_logprobs[token_id].decoded_token
-            logprobs.tokens.append(token)
-            logprobs.token_logprobs.append(token_logprob)
-            if len(logprobs.text_offset) == 0:
-                logprobs.text_offset.append(initial_text_offset)
-            else:
-                logprobs.text_offset.append(logprobs.text_offset[-1] +
-                                            last_token_len)
-            last_token_len = len(token)
-
-            if num_output_top_logprobs:
-                logprobs.top_logprobs.append({
-                    p.decoded_token: p.logprob
-                    for i, p in step_top_logprobs.items()
-                } if step_top_logprobs else None)
+                token_logprob = step_top_logprobs[token_id].logprob
+                token = step_top_logprobs[token_id].decoded_token
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(token_logprob)
+
+                if num_output_top_logprobs:
+                    logprobs.top_logprobs.append({
+                        p.decoded_token: p.logprob
+                        for i, p in step_top_logprobs.items()
+                    } if step_top_logprobs else None)
+
+            if len(logprobs.text_offset) == 0:
+                logprobs.text_offset.append(initial_text_offset)
+            else:
+                logprobs.text_offset.append(logprobs.text_offset[-1] +
+                                            last_token_len)
+            last_token_len = len(token)
         return logprobs
 
     def create_error_response(
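To make the reworked _create_logprobs loop above easier to follow in isolation, here is a simplified standalone sketch of the same idea, using plain dicts and a stub decode function instead of vLLM's tokenizer and Logprob/LogProbs types; all names in it are illustrative:

from typing import Dict, List, Optional


def build_logprobs(token_ids: List[int],
                   top_logprobs: List[Optional[Dict[int, dict]]],
                   decode,
                   initial_text_offset: int = 0) -> dict:
    """Simplified illustration of the fixed loop: positions whose logprobs are
    None (e.g. the first echoed prompt token) still get a token string and
    explicit None entries, and text offsets are tracked for every token."""
    out = {"tokens": [], "token_logprobs": [], "top_logprobs": [],
           "text_offset": []}
    last_token_len = 0
    for i, token_id in enumerate(token_ids):
        step = top_logprobs[i]
        if step is None:
            token = decode(token_id)
            out["tokens"].append(token)
            out["token_logprobs"].append(None)
            out["top_logprobs"].append(None)
        else:
            entry = step[token_id]
            token = entry["decoded_token"]
            out["tokens"].append(token)
            out["token_logprobs"].append(entry["logprob"])
            out["top_logprobs"].append(
                {p["decoded_token"]: p["logprob"] for p in step.values()})
        # Offsets are appended for every token, whether or not it has logprobs.
        if not out["text_offset"]:
            out["text_offset"].append(initial_text_offset)
        else:
            out["text_offset"].append(out["text_offset"][-1] + last_token_len)
        last_token_len = len(token)
    return out


# Example: an echoed 2-token prompt (first position has no logprob) plus one
# generated token.
fake = build_logprobs(
    token_ids=[5, 7, 9],
    top_logprobs=[None,
                  {7: {"decoded_token": " world", "logprob": -0.2}},
                  {9: {"decoded_token": "!", "logprob": -0.5}}],
    decode=lambda tid: f"<tok{tid}>",
)
assert fake["token_logprobs"][0] is None
assert len(fake["text_offset"]) == 3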
@@ -169,7 +174,7 @@ class OpenAIServing:
         prompt: Optional[str] = None,
         prompt_ids: Optional[List[int]] = None,
         truncate_prompt_tokens: Optional[conint(ge=1)] = None
-    ) -> List[int]:
+    ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
@@ -187,6 +192,8 @@ class OpenAIServing:
         else:
             input_ids = prompt_ids
 
+        input_text = prompt if prompt is not None else self.tokenizer.decode(
+            prompt_ids)
         token_num = len(input_ids)
 
         if request.max_tokens is None:
@@ -201,4 +208,4 @@ class OpenAIServing:
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:
-            return input_ids
+            return input_ids, input_text
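Finally, a minimal sketch of the new return contract of _validate_prompt_and_tokenize, which now hands back both token IDs and text so callers can pass the right form to the engine regardless of how the prompt arrived; the helper name and the Hugging Face tokenizer here are illustrative stand-ins, not vLLM API:

from typing import List, Optional, Tuple

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer choice


def normalize_prompt(prompt: Optional[str] = None,
                     prompt_ids: Optional[List[int]] = None
                     ) -> Tuple[List[int], str]:
    """Return (token IDs, text) no matter which form the caller provided."""
    if not (prompt or prompt_ids):
        raise ValueError("Either prompt or prompt_ids should be provided.")
    input_ids = tokenizer(prompt).input_ids if prompt_ids is None else prompt_ids
    input_text = prompt if prompt is not None else tokenizer.decode(prompt_ids)
    return input_ids, input_text


ids_from_text, text_a = normalize_prompt(prompt="Hello, my name is")
ids_from_ids, text_b = normalize_prompt(prompt_ids=ids_from_text)
print(ids_from_text == ids_from_ids, text_b)  # True Hello, my name is (round-trips for this tokenizer)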