From 5020e1e80c33e57ec9d4f7e17f35bc87b1a7cb01 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Sun, 11 Jun 2023 01:43:07 +0800
Subject: [PATCH] Non-streaming simple fastapi server (#144)

---
 .../entrypoints/openai/openai_frontend.py  |  2 +-
 .../entrypoints/simple_fastapi_frontend.py | 39 +++++++++++++-----
 examples/simple_fastapi_client.py          | 40 ++++++++++++++----
 3 files changed, 61 insertions(+), 20 deletions(-)

diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py
index acf09b21..125537fb 100644
--- a/cacheflow/entrypoints/openai/openai_frontend.py
+++ b/cacheflow/entrypoints/openai/openai_frontend.py
@@ -233,7 +233,7 @@ async def create_completion(raw_request: Request):
     async for res in result_generator:
         if await raw_request.is_disconnected():
             # Abort the request if the client disconnects.
-            await server.abort(request_id)
+            await abort_request()
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "Client disconnected")
         final_res = res
diff --git a/cacheflow/entrypoints/simple_fastapi_frontend.py b/cacheflow/entrypoints/simple_fastapi_frontend.py
index 1438851c..07933003 100644
--- a/cacheflow/entrypoints/simple_fastapi_frontend.py
+++ b/cacheflow/entrypoints/simple_fastapi_frontend.py
@@ -3,7 +3,7 @@ import json
 from typing import AsyncGenerator
 
 from fastapi import BackgroundTasks, FastAPI, Request
-from fastapi.responses import StreamingResponse
+from fastapi.responses import Response, StreamingResponse
 import uvicorn
 
 from cacheflow.sampling_params import SamplingParams
@@ -17,19 +17,22 @@ app = FastAPI()
 
 
 @app.post("/generate")
-async def generate_stream(request: Request) -> StreamingResponse:
-    """ Stream the results of the generation request.
+async def generate(request: Request) -> Response:
+    """Generate completion for the request.
 
     The request should be a JSON object with the following fields:
     - prompt: the prompt to use for the generation.
+    - stream: whether to stream the results or not.
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
+    stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
     results_generator = server.generate(prompt, sampling_params, request_id)
 
+    # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
             prompt = request_output.prompt
@@ -37,19 +40,35 @@
                 prompt + output.text
                 for output in request_output.outputs
             ]
-            ret = {
-                "text": text_outputs,
-                "error": 0,
-            }
+            ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
     async def abort_request() -> None:
         await server.abort(request_id)
 
-    background_tasks = BackgroundTasks()
-    # Abort the request if the client disconnects.
-    background_tasks.add_task(abort_request)
-    return StreamingResponse(stream_results(), background=background_tasks)
+    if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
+        return StreamingResponse(stream_results(), background=background_tasks)
+
+    # Non-streaming case
+    final_output = None
+    async for request_output in results_generator:
+        if await request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return Response(status_code=499)
+        final_output = request_output
+
+    assert final_output is not None
+    prompt = final_output.prompt
+    text_outputs = [
+        prompt + output.text
+        for output in final_output.outputs
+    ]
+    ret = {"text": text_outputs}
+    return Response(content=json.dumps(ret))
 
 
 if __name__ == "__main__":
diff --git a/examples/simple_fastapi_client.py b/examples/simple_fastapi_client.py
index d7d9d355..34a87cb4 100644
--- a/examples/simple_fastapi_client.py
+++ b/examples/simple_fastapi_client.py
@@ -1,15 +1,17 @@
 import argparse
-import requests
 import json
+import requests
+from typing import Iterable, List
 
 
-def clear_line(n=1):
+def clear_line(n: int = 1) -> None:
     LINE_UP = '\033[1A'
     LINE_CLEAR = '\x1b[2K'
     for i in range(n):
         print(LINE_UP, end=LINE_CLEAR, flush=True)
 
-def http_request(prompt: str, api_url: str, n: int = 1):
+def post_http_request(prompt: str, api_url: str, n: int = 1,
+                      stream: bool = False) -> requests.Response:
     headers = {"User-Agent": "Test Client"}
     pload = {
         "prompt": prompt,
@@ -17,32 +19,52 @@
         "use_beam_search": True,
         "temperature": 0.0,
         "max_tokens": 16,
+        "stream": stream,
     }
     response = requests.post(api_url, headers=headers, json=pload, stream=True)
+    return response
 
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+
+def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
+    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
+                                     delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode("utf-8"))
             output = data["text"]
             yield output
 
 
+def get_response(response: requests.Response) -> List[str]:
+    data = json.loads(response.content)
+    output = data["text"]
+    return output
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
     parser.add_argument("--n", type=int, default=4)
     parser.add_argument("--prompt", type=str, default="San Francisco is a")
+    parser.add_argument("--stream", action="store_true")
     args = parser.parse_args()
     prompt = args.prompt
     api_url = f"http://{args.host}:{args.port}/generate"
     n = args.n
+    stream = args.stream
 
     print(f"Prompt: {prompt}\n", flush=True)
-    num_printed_lines = 0
-    for h in http_request(prompt, api_url, n):
-        clear_line(num_printed_lines)
+    response = post_http_request(prompt, api_url, n, stream)
+
+    if stream:
         num_printed_lines = 0
-        for i, line in enumerate(h):
-            num_printed_lines += 1
+        for h in get_streaming_response(response):
+            clear_line(num_printed_lines)
+            num_printed_lines = 0
+            for i, line in enumerate(h):
+                num_printed_lines += 1
+                print(f"Beam candidate {i}: {line}", flush=True)
+    else:
+        output = get_response(response)
+        for i, line in enumerate(output):
            print(f"Beam candidate {i}: {line}", flush=True)
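
A minimal sketch for exercising the new non-streaming path by hand (an illustration, not part of the patch; it assumes the frontend above is running on localhost:8001, the client's default port, and the payload mirrors the fields examples/simple_fastapi_client.py sends):

    import json

    import requests

    # "stream" is the field this patch adds; the server pops it from the
    # request body and falls back to False when it is absent.
    payload = {
        "prompt": "San Francisco is a",
        "n": 4,
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": 16,
        "stream": False,  # True returns the NUL-delimited streaming response instead
    }
    response = requests.post("http://localhost:8001/generate", json=payload)

    # The non-streaming branch returns a single JSON body: {"text": [...]}.
    for i, line in enumerate(json.loads(response.content)["text"]):
        print(f"Beam candidate {i}: {line}")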