Added echo function to OpenAI API server. (#1504)
commit 665cbcec4b
parent 7c600440f7
@@ -160,16 +160,26 @@ async def show_available_models():
     return ModelList(data=model_cards)
 
 
-def create_logprobs(token_ids: List[int],
-                    id_logprobs: List[Dict[int, float]],
-                    initial_text_offset: int = 0) -> LogProbs:
+def create_logprobs(
+    token_ids: List[int],
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
+    num_output_top_logprobs: Optional[int] = None,
+    initial_text_offset: int = 0,
+) -> LogProbs:
     """Create OpenAI-style logprobs."""
     logprobs = LogProbs()
     last_token_len = 0
-    for token_id, id_logprob in zip(token_ids, id_logprobs):
+    if num_output_top_logprobs:
+        logprobs.top_logprobs = []
+    for i, token_id in enumerate(token_ids):
+        step_top_logprobs = top_logprobs[i]
+        if step_top_logprobs is not None:
+            token_logprob = step_top_logprobs[token_id]
+        else:
+            token_logprob = None
         token = tokenizer.convert_ids_to_tokens(token_id)
         logprobs.tokens.append(token)
-        logprobs.token_logprobs.append(id_logprob[token_id])
+        logprobs.token_logprobs.append(token_logprob)
         if len(logprobs.text_offset) == 0:
             logprobs.text_offset.append(initial_text_offset)
         else:
@@ -177,10 +187,11 @@ def create_logprobs(token_ids: List[int],
                                         last_token_len)
         last_token_len = len(token)
 
-        logprobs.top_logprobs.append({
-            tokenizer.convert_ids_to_tokens(i): p
-            for i, p in id_logprob.items()
-        })
+        if num_output_top_logprobs:
+            logprobs.top_logprobs.append({
+                tokenizer.convert_ids_to_tokens(i): p
+                for i, p in step_top_logprobs.items()
+            } if step_top_logprobs else None)
     return logprobs
 
 
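An illustrative aside, not part of the diff: a minimal standalone sketch of the refactored loop above. It shows why the new code tolerates a None entry per step (the first prompt token carries no logprob) instead of indexing id_logprob unconditionally; the helper name is hypothetical.

from typing import Dict, List, Optional

def sketch_token_logprobs(
    token_ids: List[int],
    top_logprobs: List[Optional[Dict[int, float]]],
) -> List[Optional[float]]:
    # Mirrors the new loop body: a missing per-step dict yields None
    # instead of raising a KeyError, as the old code would have.
    out: List[Optional[float]] = []
    for i, token_id in enumerate(token_ids):
        step_top_logprobs = top_logprobs[i]
        out.append(step_top_logprobs[token_id]
                   if step_top_logprobs is not None else None)
    return out

print(sketch_token_logprobs([5, 7], [None, {7: -0.1, 3: -2.3}]))  # [None, -0.1]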
@@ -371,8 +382,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     for the API specification. This API mimics the OpenAI Completion API.
 
     NOTE: Currently we do not support the following features:
-        - echo (since the vLLM engine does not currently support
-          getting the logprobs of prompt tokens)
         - suffix (the language models we currently support do not support
          suffix)
         - logit_bias (to be supported by vLLM engine)
@@ -383,11 +392,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     if error_check_ret is not None:
         return error_check_ret
 
-    if request.echo:
-        # We do not support echo since the vLLM engine does not
-        # currently support getting the logprobs of prompt tokens.
-        return create_error_response(HTTPStatus.BAD_REQUEST,
-                                     "echo is not currently supported")
+    # OpenAI API supports echoing the prompt when max_tokens is 0.
+    echo_without_generation = request.echo and request.max_tokens == 0
 
     if request.suffix is not None:
         # The language models we currently support do not support suffix.
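Not part of the diff: with the error branch gone, echo=True is honored, and echo combined with max_tokens == 0 means "return the scored prompt without generating anything new". A hedged client-side sketch, assuming a vLLM OpenAI-compatible server on localhost:8000 and a placeholder model name:

import requests

# Prompt-scoring request: echo the prompt back, generate no new tokens.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "my-model",   # placeholder; use the served model name
        "prompt": "The capital of France is",
        "echo": True,
        "max_tokens": 0,       # echo-only: no generation requested
        "logprobs": 1,         # also return per-token logprobs
    },
)
print(resp.json()["choices"][0]["text"])  # starts with the original prompt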
@@ -443,9 +449,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             stop=request.stop,
             stop_token_ids=request.stop_token_ids,
             ignore_eos=request.ignore_eos,
-            max_tokens=request.max_tokens,
+            max_tokens=request.max_tokens
+            if not echo_without_generation else 1,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            prompt_logprobs=request.logprobs if request.echo else None,
             skip_special_tokens=request.skip_special_tokens,
             spaces_between_special_tokens=spaces_between_special_tokens,
         )
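For illustration only, outside the diff: the sampling parameters now ask the engine for prompt logprobs whenever echo is set, and bump max_tokens to 1 in the echo-only case since the engine still has to run at least one step. A rough offline sketch of the same engine-side knobs, assuming a small locally available model (not the server code):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")   # any small local model works here
params = SamplingParams(
    max_tokens=1,         # engine-side minimum even when only echoing
    logprobs=1,
    prompt_logprobs=1,    # request logprobs for the prompt tokens too
)
out = llm.generate(["The capital of France is"], params)[0]
print(out.prompt_token_ids)
print(out.prompt_logprobs)  # one entry per prompt token; the first is None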
@@ -495,24 +503,42 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     async def completion_stream_generator() -> AsyncGenerator[str, None]:
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
+        has_echoed = [False] * request.n
         async for res in result_generator:
             res: RequestOutput
             for output in res.outputs:
                 i = output.index
                 delta_text = output.text[len(previous_texts[i]):]
+                token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[previous_num_tokens[i]:]
+                offsets = len(previous_texts[i])
+                if request.echo and not has_echoed[i]:
+                    if not echo_without_generation:
+                        delta_text = res.prompt + delta_text
+                        token_ids = res.prompt_token_ids + token_ids
+                        top_logprobs = res.prompt_logprobs + top_logprobs
+                    else:
+                        delta_text = res.prompt
+                        token_ids = res.prompt_token_ids
+                        top_logprobs = res.prompt_logprobs
+                    has_echoed[i] = True
                 if request.logprobs is not None:
                     logprobs = create_logprobs(
-                        output.token_ids[previous_num_tokens[i]:],
-                        output.logprobs[previous_num_tokens[i]:],
-                        len(previous_texts[i]))
+                        token_ids=token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=offsets,
+                    )
                 else:
                     logprobs = None
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
+                finish_reason = output.finish_reason
                 response_json = create_stream_response_json(
                     index=i,
                     text=delta_text,
                     logprobs=logprobs,
+                    finish_reason=finish_reason,
                 )
                 yield f"data: {response_json}\n\n"
                 if output.finish_reason is not None:
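Not part of the diff: on the streaming path, has_echoed makes sure the prompt text, token ids, and prompt logprobs are prepended to the first delta only. A hedged sketch of a client consuming that stream (same local-server and placeholder-model assumptions as the earlier sketch):

import json
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "my-model",   # placeholder
        "prompt": "Once upon a time",
        "echo": True,
        "max_tokens": 16,
        "stream": True,
    },
    stream=True,
)
first = True
for line in resp.iter_lines():
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    text = chunk["choices"][0]["text"]
    if first:
        # The first chunk carries the echoed prompt followed by new text.
        assert text.startswith("Once upon a time")
        first = False
    print(text, end="")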
@@ -551,14 +577,36 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         final_res = res
     assert final_res is not None
     choices = []
+    prompt_token_ids = final_res.prompt_token_ids
+    prompt_logprobs = final_res.prompt_logprobs
+    prompt_text = final_res.prompt
     for output in final_res.outputs:
         if request.logprobs is not None:
-            logprobs = create_logprobs(output.token_ids, output.logprobs)
+            if not echo_without_generation:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+                if request.echo:
+                    token_ids = prompt_token_ids + token_ids
+                    top_logprobs = prompt_logprobs + top_logprobs
+            else:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+            logprobs = create_logprobs(
+                token_ids=token_ids,
+                top_logprobs=top_logprobs,
+                num_output_top_logprobs=request.logprobs,
+            )
         else:
             logprobs = None
+        if not echo_without_generation:
+            output_text = output.text
+            if request.echo:
+                output_text = prompt_text + output_text
+        else:
+            output_text = prompt_text
         choice_data = CompletionResponseChoice(
             index=output.index,
-            text=output.text,
+            text=output_text,
             logprobs=logprobs,
             finish_reason=output.finish_reason,
         )
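Not part of the diff: the non-streaming path is handled analogously, concatenating prompt text, token ids, and prompt logprobs in front of the completion, and returning only the prompt when max_tokens == 0. A small sketch of what a caller can now expect, where `data` is the hypothetical parsed JSON body of a non-streaming echo request like the one sketched earlier:

choice = data["choices"][0]
print(choice["text"])             # begins with the echoed prompt
lp = choice["logprobs"]
print(lp["tokens"][:5])           # prompt tokens come first
print(lp["token_logprobs"][:5])   # the leading entry is None: the first
                                  # prompt token has no logprob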
@@ -106,8 +106,7 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str,
-                                     float]]] = Field(default_factory=list)
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
 
 
 class CompletionResponseChoice(BaseModel):
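An aside, not part of the diff: top_logprobs goes from an always-present list keyed by token strings to an optional list keyed by token ids, defaulting to None when the client did not ask for top logprobs. A standalone sketch of the updated pydantic model with the imports it needs (same field definitions as the hunk above):

from typing import Dict, List, Optional
from pydantic import BaseModel, Field

class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None

lp = LogProbs()
print(lp.top_logprobs)   # None until create_logprobs initializes it to []
lp.top_logprobs = [{42: -0.5}, None]
print(lp.model_dump() if hasattr(lp, "model_dump") else lp.dict())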