From 66c54aa9c33555a6b41421d57d3ad6c1bf004ec9 Mon Sep 17 00:00:00 2001
From: Nicolas Basile
Date: Tue, 8 Aug 2023 17:43:49 -0700
Subject: [PATCH] Check the max prompt length for the OpenAI completions API (#472)

---
 vllm/entrypoints/openai/api_server.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 81004d3c..8acea787 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -120,7 +120,7 @@ async def check_length(request, prompt):
     token_num = len(input_ids)
 
     if token_num + request.max_tokens > max_model_len:
-        return create_error_response(
+        return input_ids, create_error_response(
             HTTPStatus.BAD_REQUEST,
             f"This model's maximum context length is {max_model_len} tokens. "
             f"However, you requested {request.max_tokens + token_num} tokens "
@@ -129,7 +129,7 @@ async def check_length(request, prompt):
             f"Please reduce the length of the messages or completion.",
         )
     else:
-        return None
+        return input_ids, None
 
 
 @app.get("/v1/models")
@@ -191,7 +191,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt)
+    token_ids, error_check_ret = await check_length(request, prompt)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -215,7 +215,8 @@ async def create_chat_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id,
+                                       token_ids)
 
     async def abort_request() -> None:
         await engine.abort(request_id)
@@ -386,6 +387,11 @@ async def create_completion(raw_request: Request):
             prompt = request.prompt[0]
         else:
             prompt = request.prompt
+
+    token_ids, error_check_ret = await check_length(request, prompt)
+    if error_check_ret is not None:
+        return error_check_ret
+
     created_time = int(time.time())
     try:
         sampling_params = SamplingParams(
@@ -405,7 +411,8 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id,
+                                       token_ids)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
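
For context, a minimal sketch of the call pattern this patch gives both endpoints: tokenize and length-check the prompt once, return the error response if the request would exceed the context window, and otherwise hand the already-computed token ids to the engine so the prompt is not tokenized a second time. Names such as check_length, engine.generate, sampling_params, and request_id are taken from the diff above; this is an illustrative sketch, not the exact api_server.py source.

    # Illustrative sketch only; assumes check_length, engine, sampling_params,
    # and request_id exist as in the diff above.

    # check_length now returns (input_ids, error_or_None) instead of just the error.
    token_ids, error_check_ret = await check_length(request, prompt)
    if error_check_ret is not None:
        # Prompt tokens plus requested max_tokens exceed max_model_len:
        # return the HTTP 400 error response instead of generating.
        return error_check_ret

    # Reuse the token ids from the length check so the engine does not
    # re-tokenize the prompt.
    result_generator = engine.generate(prompt, sampling_params, request_id,
                                       token_ids)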