Fix echo/logprob OpenAI completion bug (#3441)
Co-authored-by: Dylan Hawk <dylanwawk@gmail.com>
parent 559eb852f8
commit 95e7d4a97c
@@ -742,5 +742,36 @@ number: "1" | "2"
     assert content.strip() == ground_truth
 
 
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
+                                       model_name: str):
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+    # test using text and token IDs
+    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
+        completion = await client.completions.create(model=model_name,
+                                                      prompt=prompt,
+                                                      max_tokens=5,
+                                                      temperature=0.0,
+                                                      echo=True,
+                                                      logprobs=1)
+
+        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
+                                                             list) else prompt
+        assert (completion.choices[0].text is not None
+                and re.search(r"^" + prompt_text, completion.choices[0].text))
+        logprobs = completion.choices[0].logprobs
+        assert logprobs is not None
+        assert len(logprobs.text_offset) > 5
+        assert (len(logprobs.token_logprobs) > 5
+                and logprobs.token_logprobs[0] is None)
+        assert (len(logprobs.top_logprobs) > 5
+                and logprobs.top_logprobs[0] is None)
+        assert len(logprobs.tokens) > 5
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
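The new test pins down what the fix restores: with echo=True the returned text starts with the prompt itself, and with logprobs requested the prompt tokens appear in the logprob arrays, with the first entry left as None because the first prompt token has no logprob. A minimal, self-contained sketch of those invariants against a hand-built choice dict that merely mimics the OpenAI completion schema (all values below are made up for illustration):

# Illustrative check of the echo/logprob invariants exercised by the test.
choice = {
    "text": "Hello, my name is Anna and I",
    "logprobs": {
        "tokens": ["Hello", ",", " my", " name", " is", " Anna", " and", " I"],
        "token_logprobs": [None, -1.2, -0.8, -0.3, -0.1, -2.5, -0.9, -1.1],
        "top_logprobs": [None, {",": -1.2}, {" my": -0.8}, {" name": -0.3},
                         {" is": -0.1}, {" Anna": -2.5}, {" and": -0.9},
                         {" I": -1.1}],
        "text_offset": [0, 5, 6, 9, 14, 17, 22, 26],
    },
}

prompt_text = "Hello, my name is"
logprobs = choice["logprobs"]

# echo=True: the completion text must start with the prompt itself.
assert choice["text"].startswith(prompt_text)
# The prompt tokens are part of the logprob arrays...
assert len(logprobs["tokens"]) > 5
# ...and the first prompt token has no logprob, so the first entry is None.
assert logprobs["token_logprobs"][0] is None
assert logprobs["top_logprobs"][0] is None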
@@ -63,8 +63,9 @@ class OpenAIServingChat(OpenAIServing):
 
         request_id = f"cmpl-{random_uuid()}"
         try:
-            token_ids = self._validate_prompt_and_tokenize(request,
-                                                           prompt=prompt)
+            # Tokenize/detokenize depending on prompt format (string/token list)
+            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+                request, prompt=prompt)
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
             guided_decode_logits_processor = (
@@ -78,8 +79,8 @@ class OpenAIServingChat(OpenAIServing):
         except ValueError as e:
             return self.create_error_response(str(e))
 
-        result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids,
+        result_generator = self.engine.generate(prompt_text, sampling_params,
+                                                request_id, prompt_ids,
                                                 lora_request)
         # Streaming response
         if request.stream:
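On the chat path the earlier code fetched only the token IDs from the tokenize helper and still passed the original prompt object to engine.generate(). After the change both representations come out of a single call and are forwarded together. A rough sketch of that flow, with a stand-in tokenize() helper and FakeEngine in place of the real tokenizer and AsyncLLMEngine (both are hypothetical and only illustrate the calling convention):

from typing import List, Tuple

def tokenize(prompt: str) -> Tuple[List[int], str]:
    # Stand-in for _validate_prompt_and_tokenize: return the token IDs and the
    # prompt text together so callers never have to re-derive one from the other.
    fake_vocab = {"Hello": 1, "world": 2}
    ids = [fake_vocab.get(tok, 0) for tok in prompt.split()]
    return ids, prompt

class FakeEngine:
    def generate(self, prompt_text, sampling_params, request_id,
                 prompt_token_ids, lora_request=None):
        # The real engine schedules generation; here we just echo the inputs
        # to show that text and IDs travel together.
        return f"{request_id}: text={prompt_text!r}, ids={prompt_token_ids}"

engine = FakeEngine()
prompt_ids, prompt_text = tokenize("Hello world")   # one call, both forms
print(engine.generate(prompt_text, None, "cmpl-123", prompt_ids))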
@@ -136,23 +136,24 @@ class OpenAIServingCompletion(OpenAIServing):
 
             for i, prompt in enumerate(prompts):
                 if prompt_is_tokens:
-                    input_ids = self._validate_prompt_and_tokenize(
+                    prompt_formats = self._validate_prompt_and_tokenize(
                         request,
                         prompt_ids=prompt,
                         truncate_prompt_tokens=sampling_params.
                         truncate_prompt_tokens)
                 else:
-                    input_ids = self._validate_prompt_and_tokenize(
+                    prompt_formats = self._validate_prompt_and_tokenize(
                         request,
                         prompt=prompt,
                         truncate_prompt_tokens=sampling_params.
                         truncate_prompt_tokens)
+                prompt_ids, prompt_text = prompt_formats
 
                 generators.append(
-                    self.engine.generate(prompt,
+                    self.engine.generate(prompt_text,
                                          sampling_params,
                                          f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids,
+                                         prompt_token_ids=prompt_ids,
                                          lora_request=lora_request))
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
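The completion path gets the same treatment: whether the incoming prompt is a string or a list of token IDs, the helper now hands back a matching (prompt_ids, prompt_text) pair, so the generator never has to guess the missing half. A simplified sketch with a hypothetical validate() standing in for _validate_prompt_and_tokenize:

from typing import List, Tuple, Union

def validate(prompt: Union[str, List[int]]) -> Tuple[List[int], str]:
    # Hypothetical stand-in: tokenize strings, detokenize token lists, so both
    # representations are always available downstream.
    if isinstance(prompt, str):
        return [len(word) for word in prompt.split()], prompt
    return list(prompt), " ".join(f"<tok{i}>" for i in prompt)

# Mirrors the two prompt formats exercised by the new test above.
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
    prompt_ids, prompt_text = validate(prompt)
    print(prompt_text, prompt_ids)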
@@ -326,7 +327,8 @@ class OpenAIServingCompletion(OpenAIServing):
                 output_text = prompt_text
             elif request.echo and request.max_tokens > 0:
                 token_ids = prompt_token_ids + output.token_ids
-                top_logprobs = prompt_logprobs + output.logprobs
+                top_logprobs = (prompt_logprobs + output.logprobs
+                                if request.logprobs else None)
                 output_text = prompt_text + output.text
             else:
                 token_ids = output.token_ids
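This is the guard at the heart of the echo fix: when echo=True and max_tokens > 0, prompt and output logprobs are only concatenated if the client actually asked for logprobs; otherwise both operands can be None and the old unguarded + would raise a TypeError. A tiny sketch of the guarded combination (names follow the diff, values are made up):

# request.logprobs is None when the client did not ask for logprobs.
for requested_logprobs in (1, None):
    prompt_logprobs = [None, {"Hello": -0.5}] if requested_logprobs else None
    output_logprobs = [{" world": -0.7}] if requested_logprobs else None

    # Guarded concatenation, as in the patched completion code above.
    top_logprobs = (prompt_logprobs + output_logprobs
                    if requested_logprobs else None)
    print(top_logprobs)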
@@ -334,6 +336,9 @@ class OpenAIServingCompletion(OpenAIServing):
                 output_text = output.text
 
             if request.logprobs is not None:
+                assert top_logprobs is not None, (
+                    "top_logprobs must be provided when logprobs "
+                    "is requested")
                 logprobs = self._create_logprobs(
                     token_ids=token_ids,
                     top_logprobs=top_logprobs,
@@ -2,7 +2,7 @@ import asyncio
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 from pydantic import conint
 
@@ -99,27 +99,32 @@ class OpenAIServing:
         last_token_len = 0
         if num_output_top_logprobs:
             logprobs.top_logprobs = []
+
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
-            if step_top_logprobs is not None:
-                token_logprob = step_top_logprobs[token_id].logprob
+            if step_top_logprobs is None:
+                token = self.tokenizer.decode(token_id)
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(None)
+                logprobs.top_logprobs.append(None)
             else:
-                token_logprob = None
-            token = step_top_logprobs[token_id].decoded_token
-            logprobs.tokens.append(token)
-            logprobs.token_logprobs.append(token_logprob)
+                token_logprob = step_top_logprobs[token_id].logprob
+                token = step_top_logprobs[token_id].decoded_token
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(token_logprob)
+
+                if num_output_top_logprobs:
+                    logprobs.top_logprobs.append({
+                        p.decoded_token: p.logprob
+                        for i, p in step_top_logprobs.items()
+                    } if step_top_logprobs else None)
+
             if len(logprobs.text_offset) == 0:
                 logprobs.text_offset.append(initial_text_offset)
             else:
                 logprobs.text_offset.append(logprobs.text_offset[-1] +
                                             last_token_len)
             last_token_len = len(token)
-
-            if num_output_top_logprobs:
-                logprobs.top_logprobs.append({
-                    p.decoded_token: p.logprob
-                    for i, p in step_top_logprobs.items()
-                } if step_top_logprobs else None)
         return logprobs
 
     def create_error_response(
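The reworked loop in _create_logprobs also has to cope with positions that carry no logprob information, typically the first prompt token when echoing: it decodes the token ID directly and records None placeholders instead of indexing into a missing dict. A condensed sketch of just that branch logic, with a plain dict and a small dataclass standing in for the tokenizer and vLLM's Logprob objects:

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class FakeLogprob:
    # Minimal stand-in for vLLM's Logprob object.
    logprob: float
    decoded_token: str

# Pretend decoder in place of self.tokenizer.decode().
decode = {1: "Hello", 2: " world"}.get

token_ids = [1, 2]
# Entry 0 is None: there is no logprob for the first prompt token.
top_logprobs: List[Optional[Dict[int, FakeLogprob]]] = [
    None,
    {2: FakeLogprob(-0.7, " world"), 3: FakeLogprob(-1.4, " there")},
]

tokens, token_logprobs, top_out = [], [], []
for token_id, step in zip(token_ids, top_logprobs):
    if step is None:
        # No logprob info: decode the ID ourselves and record placeholders.
        tokens.append(decode(token_id))
        token_logprobs.append(None)
        top_out.append(None)
    else:
        tokens.append(step[token_id].decoded_token)
        token_logprobs.append(step[token_id].logprob)
        top_out.append({p.decoded_token: p.logprob for p in step.values()})

print(tokens)          # ['Hello', ' world']
print(token_logprobs)  # [None, -0.7]
print(top_out)         # [None, {' world': -0.7, ' there': -1.4}]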
@@ -164,12 +169,12 @@ class OpenAIServing:
         raise ValueError("The model `{request.model}` does not exist.")
 
     def _validate_prompt_and_tokenize(
-            self,
-            request: Union[ChatCompletionRequest, CompletionRequest],
-            prompt: Optional[str] = None,
-            prompt_ids: Optional[List[int]] = None,
-            truncate_prompt_tokens: Optional[conint(ge=1)] = None
-    ) -> List[int]:
+            self,
+            request: Union[ChatCompletionRequest, CompletionRequest],
+            prompt: Optional[str] = None,
+            prompt_ids: Optional[List[int]] = None,
+            truncate_prompt_tokens: Optional[conint(ge=1)] = None
+    ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
@@ -187,6 +192,8 @@ class OpenAIServing:
         else:
             input_ids = prompt_ids
 
+        input_text = prompt if prompt is not None else self.tokenizer.decode(
+            prompt_ids)
         token_num = len(input_ids)
 
         if request.max_tokens is None:
@@ -201,4 +208,4 @@ class OpenAIServing:
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:
-            return input_ids
+            return input_ids, input_text
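The signature change ties it all together: _validate_prompt_and_tokenize now returns a (token IDs, text) tuple instead of only the IDs, so every caller sees a consistent pair no matter which form the request supplied. A minimal sketch of that contract with a toy tokenizer (the helper below is illustrative, not the actual vLLM implementation):

from typing import List, Optional, Tuple

class ToyTokenizer:
    # Stand-in for the real tokenizer: encode by word length, decode to markers.
    def encode(self, text: str) -> List[int]:
        return [len(word) for word in text.split()]
    def decode(self, ids: List[int]) -> str:
        return " ".join(f"<{i}>" for i in ids)

def validate_prompt_and_tokenize(
        tokenizer: ToyTokenizer,
        prompt: Optional[str] = None,
        prompt_ids: Optional[List[int]] = None) -> Tuple[List[int], str]:
    # Mirrors the patched helper: whichever form is missing gets derived,
    # and both are returned together.
    if not (prompt or prompt_ids):
        raise ValueError("Either prompt or prompt_ids should be provided.")
    input_ids = prompt_ids if prompt_ids is not None else tokenizer.encode(prompt)
    input_text = prompt if prompt is not None else tokenizer.decode(prompt_ids)
    return input_ids, input_text

tok = ToyTokenizer()
print(validate_prompt_and_tokenize(tok, prompt="Hello world"))  # ([5, 5], 'Hello world')
print(validate_prompt_and_tokenize(tok, prompt_ids=[1, 2, 3]))  # ([1, 2, 3], '<1> <2> <3>')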