Fix echo/logprob OpenAI completion bug (#3441)

Co-authored-by: Dylan Hawk <dylanwawk@gmail.com>
Dylan Hawk 2024-04-11 15:15:50 -07:00 committed by GitHub
parent 559eb852f8
commit 95e7d4a97c
4 changed files with 73 additions and 29 deletions
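
For context, the bug being fixed is the combination of echo=True and logprobs on the /v1/completions endpoint: the server should return the prompt followed by the generated text, with per-token logprobs that also cover the echoed prompt tokens. A minimal reproduction sketch against a running OpenAI-compatible vLLM server is below; the base URL, API key, and model name are placeholders, not values taken from this commit.

    import openai

    # Placeholder endpoint/key/model; point these at your own vLLM deployment.
    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    completion = client.completions.create(
        model="zephyr-lora",      # placeholder model name
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
        echo=True,                # prepend the prompt to the returned text
        logprobs=1,               # per-token logprobs, echoed prompt tokens included
    )

    choice = completion.choices[0]
    print(choice.text)                          # starts with the echoed prompt
    print(choice.logprobs.token_logprobs[:3])   # first echoed token has logprob None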


@@ -742,5 +742,36 @@ number: "1" | "2"
     assert content.strip() == ground_truth
 
 
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
+                                        model_name: str):
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+    # test using text and token IDs
+    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
+        completion = await client.completions.create(model=model_name,
+                                                      prompt=prompt,
+                                                      max_tokens=5,
+                                                      temperature=0.0,
+                                                      echo=True,
+                                                      logprobs=1)
+
+        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
+                                                             list) else prompt
+        assert (completion.choices[0].text is not None
+                and re.search(r"^" + prompt_text, completion.choices[0].text))
+        logprobs = completion.choices[0].logprobs
+        assert logprobs is not None
+        assert len(logprobs.text_offset) > 5
+        assert (len(logprobs.token_logprobs) > 5
+                and logprobs.token_logprobs[0] is None)
+        assert (len(logprobs.top_logprobs) > 5
+                and logprobs.top_logprobs[0] is None)
+        assert len(logprobs.tokens) > 5
+
+
 if __name__ == "__main__":
     pytest.main([__file__])


@@ -63,8 +63,9 @@ class OpenAIServingChat(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         try:
-            token_ids = self._validate_prompt_and_tokenize(request,
-                                                           prompt=prompt)
+            # Tokenize/detokenize depending on prompt format (string/token list)
+            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+                request, prompt=prompt)
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
             guided_decode_logits_processor = (
@@ -78,8 +79,8 @@ class OpenAIServingChat(OpenAIServing):
         except ValueError as e:
             return self.create_error_response(str(e))
 
-        result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids,
+        result_generator = self.engine.generate(prompt_text, sampling_params,
+                                                request_id, prompt_ids,
                                                 lora_request)
         # Streaming response
         if request.stream:


@@ -136,23 +136,24 @@ class OpenAIServingCompletion(OpenAIServing):
             for i, prompt in enumerate(prompts):
                 if prompt_is_tokens:
-                    input_ids = self._validate_prompt_and_tokenize(
+                    prompt_formats = self._validate_prompt_and_tokenize(
                         request,
                         prompt_ids=prompt,
                         truncate_prompt_tokens=sampling_params.
                         truncate_prompt_tokens)
                 else:
-                    input_ids = self._validate_prompt_and_tokenize(
+                    prompt_formats = self._validate_prompt_and_tokenize(
                         request,
                         prompt=prompt,
                         truncate_prompt_tokens=sampling_params.
                         truncate_prompt_tokens)
+                prompt_ids, prompt_text = prompt_formats
 
                 generators.append(
-                    self.engine.generate(prompt,
+                    self.engine.generate(prompt_text,
                                          sampling_params,
                                          f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids,
+                                         prompt_token_ids=prompt_ids,
                                          lora_request=lora_request))
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
@@ -326,7 +327,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     output_text = prompt_text
                 elif request.echo and request.max_tokens > 0:
                     token_ids = prompt_token_ids + output.token_ids
-                    top_logprobs = prompt_logprobs + output.logprobs
+                    top_logprobs = (prompt_logprobs + output.logprobs
+                                    if request.logprobs else None)
                     output_text = prompt_text + output.text
                 else:
                     token_ids = output.token_ids
@@ -334,6 +336,9 @@ class OpenAIServingCompletion(OpenAIServing):
                     output_text = output.text
 
                 if request.logprobs is not None:
+                    assert top_logprobs is not None, (
+                        "top_logprobs must be provided when logprobs "
+                        "is requested")
                     logprobs = self._create_logprobs(
                         token_ids=token_ids,
                         top_logprobs=top_logprobs,


@@ -2,7 +2,7 @@ import asyncio
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 from pydantic import conint
@@ -99,27 +99,32 @@ class OpenAIServing:
         last_token_len = 0
         if num_output_top_logprobs:
             logprobs.top_logprobs = []
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
-            if step_top_logprobs is not None:
-                token_logprob = step_top_logprobs[token_id].logprob
+            if step_top_logprobs is None:
+                token = self.tokenizer.decode(token_id)
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(None)
+                logprobs.top_logprobs.append(None)
             else:
-                token_logprob = None
+                token_logprob = step_top_logprobs[token_id].logprob
+                token = step_top_logprobs[token_id].decoded_token
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(token_logprob)
 
-            if len(logprobs.text_offset) == 0:
-                logprobs.text_offset.append(initial_text_offset)
-            else:
-                logprobs.text_offset.append(logprobs.text_offset[-1] +
-                                            last_token_len)
-            last_token_len = len(token)
+                if num_output_top_logprobs:
+                    logprobs.top_logprobs.append({
+                        p.decoded_token: p.logprob
+                        for i, p in step_top_logprobs.items()
+                    } if step_top_logprobs else None)
+
+            if len(logprobs.text_offset) == 0:
+                logprobs.text_offset.append(initial_text_offset)
+            else:
+                logprobs.text_offset.append(logprobs.text_offset[-1] +
+                                            last_token_len)
+            last_token_len = len(token)
+
         return logprobs
 
     def create_error_response(
@@ -169,7 +174,7 @@ class OpenAIServing:
             prompt: Optional[str] = None,
             prompt_ids: Optional[List[int]] = None,
             truncate_prompt_tokens: Optional[conint(ge=1)] = None
-    ) -> List[int]:
+    ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
@@ -187,6 +192,8 @@ class OpenAIServing:
         else:
             input_ids = prompt_ids
 
+        input_text = prompt if prompt is not None else self.tokenizer.decode(
+            prompt_ids)
         token_num = len(input_ids)
 
         if request.max_tokens is None:
@@ -201,4 +208,4 @@ class OpenAIServing:
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:
-            return input_ids
+            return input_ids, input_text
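
As a rough illustration of the payload these changes are meant to produce (field names follow the OpenAI completions logprobs format used above; the token strings and numbers below are invented, not taken from a real response): when echoing, the first prompt token has nothing to condition on, so its entries in token_logprobs and top_logprobs are None, and text_offset accumulates the character length of each decoded token.

    # Hypothetical logprobs fragment for echo=True, logprobs=1 (values invented).
    logprobs_fragment = {
        "tokens": ["Hello", ",", " my", " name", " is", " John"],
        # No logprob exists for the very first echoed prompt token.
        "token_logprobs": [None, -1.9, -0.4, -0.1, -0.2, -2.3],
        # top_logprobs mirrors that: None first, then decoded-token -> logprob maps.
        "top_logprobs": [None, {",": -1.9}, {" my": -0.4}, {" name": -0.1},
                         {" is": -0.2}, {" John": -2.3}],
        # Cumulative character offsets of each token within choices[0].text.
        "text_offset": [0, 5, 6, 9, 14, 17],
    }

    # These mirror the assertions in the new test above.
    assert logprobs_fragment["token_logprobs"][0] is None
    assert logprobs_fragment["top_logprobs"][0] is None
    assert len(logprobs_fragment["text_offset"]) == len(logprobs_fragment["tokens"])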