Clean up API code and remove the redundant background task. (#1102)

This commit is contained in:
Roy 2023-09-22 04:25:05 +08:00 committed by GitHub
parent 1ac4ccf73c
commit 2d1e86f1b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 27 deletions

View File

@ -2,7 +2,7 @@ import argparse
import json import json
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi import BackgroundTasks, FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn import uvicorn
@ -44,14 +44,8 @@ async def generate(request: Request) -> Response:
ret = {"text": text_outputs} ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8") yield (json.dumps(ret) + "\0").encode("utf-8")
async def abort_request() -> None:
await engine.abort(request_id)
if stream: if stream:
background_tasks = BackgroundTasks() return StreamingResponse(stream_results())
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)
# Non-streaming case # Non-streaming case
final_output = None final_output = None

View File

@ -10,7 +10,7 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
import fastapi import fastapi
import uvicorn import uvicorn
from fastapi import BackgroundTasks, Request from fastapi import Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
@ -229,9 +229,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
result_generator = engine.generate(prompt, sampling_params, request_id, result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids) token_ids)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json( def create_stream_response_json(
index: int, index: int,
text: str, text: str,
@ -291,19 +288,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
# Streaming response # Streaming response
if request.stream: if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(), return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream", media_type="text/event-stream")
background=background_tasks)
# Non-streaming response # Non-streaming response
final_res: RequestOutput = None final_res: RequestOutput = None
async for res in result_generator: async for res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await abort_request() await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected") "Client disconnected")
final_res = res final_res = res
@ -448,9 +441,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
and (request.best_of is None or request.n == request.best_of) and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search) and not request.use_beam_search)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json( def create_stream_response_json(
index: int, index: int,
text: str, text: str,
@ -510,19 +500,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
# Streaming response # Streaming response
if stream: if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(), return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream", media_type="text/event-stream")
background=background_tasks)
# Non-streaming response # Non-streaming response
final_res: RequestOutput = None final_res: RequestOutput = None
async for res in result_generator: async for res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await abort_request() await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected") "Client disconnected")
final_res = res final_res = res