Fix echo/logprob OpenAI completion bug (#3441)

Co-authored-by: Dylan Hawk <dylanwawk@gmail.com>
Author: Dylan Hawk
Date:   2024-04-11 15:15:50 -07:00 (committed by GitHub)
parent 559eb852f8
commit 95e7d4a97c
4 changed files with 73 additions and 29 deletions
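
For reference, the behaviour this commit fixes can be exercised from any OpenAI-compatible client once a vLLM server is running. The sketch below mirrors the new test; the base URL, API key, and model name are placeholder assumptions, not part of the commit.

# Minimal reproduction sketch (assumptions: a vLLM OpenAI-compatible server
# listening on localhost:8000 and a placeholder model name).
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="my-model",  # placeholder; use whatever model the server is serving
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    echo=True,    # return the prompt as part of the completion text
    logprobs=1,   # request per-token logprobs, including prompt tokens
)

choice = completion.choices[0]
# With the fix, the returned text starts with the echoed prompt and the very
# first prompt token reports a logprob of None, as the new test asserts.
print(choice.text)
print(choice.logprobs.tokens[:3], choice.logprobs.token_logprobs[:3])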


@@ -742,5 +742,36 @@ number: "1" | "2"
     assert content.strip() == ground_truth


+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
+                                       model_name: str):
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+    # test using text and token IDs
+    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
+        completion = await client.completions.create(model=model_name,
+                                                      prompt=prompt,
+                                                      max_tokens=5,
+                                                      temperature=0.0,
+                                                      echo=True,
+                                                      logprobs=1)
+
+        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
+                                                             list) else prompt
+        assert (completion.choices[0].text is not None
+                and re.search(r"^" + prompt_text, completion.choices[0].text))
+        logprobs = completion.choices[0].logprobs
+        assert logprobs is not None
+        assert len(logprobs.text_offset) > 5
+        assert (len(logprobs.token_logprobs) > 5
+                and logprobs.token_logprobs[0] is None)
+        assert (len(logprobs.top_logprobs) > 5
+                and logprobs.top_logprobs[0] is None)
+        assert len(logprobs.tokens) > 5
+
+
 if __name__ == "__main__":
     pytest.main([__file__])


@@ -63,8 +63,9 @@ class OpenAIServingChat(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         try:
-            token_ids = self._validate_prompt_and_tokenize(request,
-                                                           prompt=prompt)
+            # Tokenize/detokenize depending on prompt format (string/token list)
+            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+                request, prompt=prompt)
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
             guided_decode_logits_processor = (
@@ -78,8 +79,8 @@ class OpenAIServingChat(OpenAIServing):
         except ValueError as e:
             return self.create_error_response(str(e))

-        result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids,
-                                                lora_request)
+        result_generator = self.engine.generate(prompt_text, sampling_params,
+                                                request_id, prompt_ids,
+                                                lora_request)
         # Streaming response
         if request.stream:


@@ -136,23 +136,24 @@ class OpenAIServingCompletion(OpenAIServing):
             for i, prompt in enumerate(prompts):
                 if prompt_is_tokens:
-                    input_ids = self._validate_prompt_and_tokenize(
+                    prompt_formats = self._validate_prompt_and_tokenize(
                         request,
                         prompt_ids=prompt,
                         truncate_prompt_tokens=sampling_params.
                         truncate_prompt_tokens)
                 else:
-                    input_ids = self._validate_prompt_and_tokenize(
+                    prompt_formats = self._validate_prompt_and_tokenize(
                         request,
                         prompt=prompt,
                         truncate_prompt_tokens=sampling_params.
                         truncate_prompt_tokens)
+                prompt_ids, prompt_text = prompt_formats

                 generators.append(
-                    self.engine.generate(prompt,
+                    self.engine.generate(prompt_text,
                                          sampling_params,
                                          f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids,
+                                         prompt_token_ids=prompt_ids,
                                          lora_request=lora_request))
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
@@ -326,7 +327,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     output_text = prompt_text
                 elif request.echo and request.max_tokens > 0:
                     token_ids = prompt_token_ids + output.token_ids
-                    top_logprobs = prompt_logprobs + output.logprobs
+                    top_logprobs = (prompt_logprobs + output.logprobs
+                                    if request.logprobs else None)
                     output_text = prompt_text + output.text
                 else:
                     token_ids = output.token_ids
@@ -334,6 +336,9 @@ class OpenAIServingCompletion(OpenAIServing):
                     output_text = output.text

                 if request.logprobs is not None:
+                    assert top_logprobs is not None, (
+                        "top_logprobs must be provided when logprobs "
+                        "is requested")
                     logprobs = self._create_logprobs(
                         token_ids=token_ids,
                         top_logprobs=top_logprobs,
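
The heart of the completion-side change is that, when echoing, prompt and output logprobs are only concatenated if the client actually asked for logprobs. Below is a stripped-down sketch of that branch; the variable names mirror the diff, but the function and its arguments are illustrative stand-ins, not vLLM's real request/output objects, and the echo-only branch is inferred rather than shown in the hunk above.

# Illustrative sketch of the echo/logprobs assembly fixed above.
from typing import List, Optional


def assemble_echoed_choice(echo: bool, logprobs: Optional[int],
                           max_tokens: int, prompt_token_ids: List[int],
                           prompt_logprobs: Optional[list], prompt_text: str,
                           output_token_ids: List[int],
                           output_logprobs: Optional[list], output_text: str):
    if echo and max_tokens == 0:
        # Echo-only request: return just the prompt (branch inferred).
        return prompt_token_ids, prompt_logprobs, prompt_text
    if echo and max_tokens > 0:
        token_ids = prompt_token_ids + output_token_ids
        # The fix: only build top_logprobs when logprobs were requested,
        # since the per-token logprob lists are None otherwise.
        top_logprobs = (prompt_logprobs + output_logprobs
                        if logprobs else None)
        return token_ids, top_logprobs, prompt_text + output_text
    return output_token_ids, output_logprobs, output_text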


@@ -2,7 +2,7 @@ import asyncio
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 from pydantic import conint
@@ -99,27 +99,32 @@ class OpenAIServing:
         last_token_len = 0
         if num_output_top_logprobs:
             logprobs.top_logprobs = []
+
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
-            if step_top_logprobs is not None:
-                token_logprob = step_top_logprobs[token_id].logprob
+            if step_top_logprobs is None:
+                token = self.tokenizer.decode(token_id)
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(None)
+                logprobs.top_logprobs.append(None)
             else:
-                token_logprob = None
-
-            token = step_top_logprobs[token_id].decoded_token
-            logprobs.tokens.append(token)
-            logprobs.token_logprobs.append(token_logprob)
+                token_logprob = step_top_logprobs[token_id].logprob
+                token = step_top_logprobs[token_id].decoded_token
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(token_logprob)
+
+                if num_output_top_logprobs:
+                    logprobs.top_logprobs.append({
+                        p.decoded_token: p.logprob
+                        for i, p in step_top_logprobs.items()
+                    } if step_top_logprobs else None)
+
             if len(logprobs.text_offset) == 0:
                 logprobs.text_offset.append(initial_text_offset)
             else:
                 logprobs.text_offset.append(logprobs.text_offset[-1] +
                                             last_token_len)
             last_token_len = len(token)
-
-            if num_output_top_logprobs:
-                logprobs.top_logprobs.append({
-                    p.decoded_token: p.logprob
-                    for i, p in step_top_logprobs.items()
-                } if step_top_logprobs else None)

         return logprobs

     def create_error_response(
@@ -164,12 +169,12 @@ class OpenAIServing:
             raise ValueError("The model `{request.model}` does not exist.")

     def _validate_prompt_and_tokenize(
             self,
             request: Union[ChatCompletionRequest, CompletionRequest],
             prompt: Optional[str] = None,
             prompt_ids: Optional[List[int]] = None,
             truncate_prompt_tokens: Optional[conint(ge=1)] = None
-    ) -> List[int]:
+    ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
@@ -187,6 +192,8 @@ class OpenAIServing:
         else:
             input_ids = prompt_ids

+        input_text = prompt if prompt is not None else self.tokenizer.decode(
+            prompt_ids)
         token_num = len(input_ids)

         if request.max_tokens is None:
@@ -201,4 +208,4 @@ class OpenAIServing:
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:
-            return input_ids
+            return input_ids, input_text
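
Since `_validate_prompt_and_tokenize` now returns both the token IDs and the prompt text, callers (see the chat and completion diffs above) unpack a tuple instead of a bare ID list. A toy sketch of that contract follows; the Tokenizer protocol and helper name are stand-ins for illustration, not the real serving code.

# Toy sketch of the (input_ids, input_text) contract introduced above.
from typing import List, Optional, Protocol, Tuple


class Tokenizer(Protocol):

    def encode(self, text: str) -> List[int]:
        ...

    def decode(self, ids: List[int]) -> str:
        ...


def validate_prompt_and_tokenize(
        tokenizer: Tokenizer,
        prompt: Optional[str] = None,
        prompt_ids: Optional[List[int]] = None) -> Tuple[List[int], str]:
    if not (prompt or prompt_ids):
        raise ValueError("Either prompt or prompt_ids should be provided.")
    input_ids = tokenizer.encode(prompt) if prompt_ids is None else prompt_ids
    # Also return the text form, decoding token-ID prompts so that echo can
    # prepend the prompt text to the generated text.
    input_text = prompt if prompt is not None else tokenizer.decode(prompt_ids)
    return input_ids, input_text


# Callers now unpack both halves, e.g.:
#     prompt_ids, prompt_text = validate_prompt_and_tokenize(tok, prompt="Hi")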