Non-streaming simple fastapi server (#144)

This commit is contained in:
Zhuohan Li 2023-06-11 01:43:07 +08:00 committed by GitHub
parent 4298374265
commit 5020e1e80c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 20 deletions

View File

@@ -233,7 +233,7 @@ async def create_completion(raw_request: Request):
async for res in result_generator: async for res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await server.abort(request_id) await abort_request()
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected") "Client disconnected")
final_res = res final_res = res

View File

@@ -3,7 +3,7 @@ import json
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi import BackgroundTasks, FastAPI, Request from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse from fastapi.responses import Response, StreamingResponse
import uvicorn import uvicorn
from cacheflow.sampling_params import SamplingParams from cacheflow.sampling_params import SamplingParams
@@ -17,19 +17,22 @@ app = FastAPI()
@app.post("/generate") @app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse: async def generate(request: Request) -> Response:
""" Stream the results of the generation request. """ Stream the results of the generation request.
The request should be a JSON object with the following fields: The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation. - prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details). - other fields: the sampling parameters (See `SamplingParams` for details).
""" """
request_dict = await request.json() request_dict = await request.json()
prompt = request_dict.pop("prompt") prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict) sampling_params = SamplingParams(**request_dict)
request_id = random_uuid() request_id = random_uuid()
results_generator = server.generate(prompt, sampling_params, request_id) results_generator = server.generate(prompt, sampling_params, request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]: async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator: async for request_output in results_generator:
prompt = request_output.prompt prompt = request_output.prompt
@@ -37,19 +37,35 @@ async def generate_stream(request: Request) -> StreamingResponse:
prompt + output.text prompt + output.text
for output in request_output.outputs for output in request_output.outputs
] ]
ret = { ret = {"text": text_outputs}
"text": text_outputs,
"error": 0,
}
yield (json.dumps(ret) + "\0").encode("utf-8") yield (json.dumps(ret) + "\0").encode("utf-8")
async def abort_request() -> None: async def abort_request() -> None:
await server.abort(request_id) await server.abort(request_id)
background_tasks = BackgroundTasks() if stream:
# Abort the request if the client disconnects. background_tasks = BackgroundTasks()
background_tasks.add_task(abort_request) # Abort the request if the client disconnects.
return StreamingResponse(stream_results(), background=background_tasks) background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)
# Non-streaming case
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await server.abort(request_id)
return Response(status_code=499)
final_output = request_output
assert final_output is not None
prompt = final_output.prompt
text_outputs = [
prompt + output.text
for output in final_output.outputs
]
ret = {"text": text_outputs}
return Response(content=json.dumps(ret))
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,15 +1,17 @@
import argparse import argparse
import requests
import json import json
import requests
from typing import Iterable, List
def clear_line(n=1): def clear_line(n:int = 1) -> None:
LINE_UP = '\033[1A' LINE_UP = '\033[1A'
LINE_CLEAR = '\x1b[2K' LINE_CLEAR = '\x1b[2K'
for i in range(n): for i in range(n):
print(LINE_UP, end=LINE_CLEAR, flush=True) print(LINE_UP, end=LINE_CLEAR, flush=True)
def http_request(prompt: str, api_url: str, n: int = 1): def post_http_request(prompt: str, api_url: str, n: int = 1,
stream: bool = False) -> requests.Response:
headers = {"User-Agent": "Test Client"} headers = {"User-Agent": "Test Client"}
pload = { pload = {
"prompt": prompt, "prompt": prompt,
@@ -17,32 +17,52 @@ def http_request(prompt: str, api_url: str, n: int = 1):
"use_beam_search": True, "use_beam_search": True,
"temperature": 0.0, "temperature": 0.0,
"max_tokens": 16, "max_tokens": 16,
"stream": stream,
} }
response = requests.post(api_url, headers=headers, json=pload, stream=True) response = requests.post(api_url, headers=headers, json=pload, stream=True)
return response
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
delimiter=b"\0"):
if chunk: if chunk:
data = json.loads(chunk.decode("utf-8")) data = json.loads(chunk.decode("utf-8"))
output = data["text"] output = data["text"]
yield output yield output
def get_response(response: requests.Response) -> List[str]:
data = json.loads(response.content)
output = data["text"]
return output
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001) parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--n", type=int, default=4) parser.add_argument("--n", type=int, default=4)
parser.add_argument("--prompt", type=str, default="San Francisco is a") parser.add_argument("--prompt", type=str, default="San Francisco is a")
parser.add_argument("--stream", action="store_true")
args = parser.parse_args() args = parser.parse_args()
prompt = args.prompt prompt = args.prompt
api_url = f"http://{args.host}:{args.port}/generate" api_url = f"http://{args.host}:{args.port}/generate"
n = args.n n = args.n
stream = args.stream
print(f"Prompt: {prompt}\n", flush=True) print(f"Prompt: {prompt}\n", flush=True)
num_printed_lines = 0 response = post_http_request(prompt, api_url, n, stream)
for h in http_request(prompt, api_url, n):
clear_line(num_printed_lines) if stream:
num_printed_lines = 0 num_printed_lines = 0
for i, line in enumerate(h): for h in get_streaming_response(response):
num_printed_lines += 1 clear_line(num_printed_lines)
num_printed_lines = 0
for i, line in enumerate(h):
num_printed_lines += 1
print(f"Beam candidate {i}: {line}", flush=True)
else:
output = get_response(response)
for i, line in enumerate(output):
print(f"Beam candidate {i}: {line}", flush=True) print(f"Beam candidate {i}: {line}", flush=True)