clean api code, remove redundant background task. (#1102)

This commit is contained in:
Roy 2023-09-22 04:25:05 +08:00 committed by GitHub
parent 1ac4ccf73c
commit 2d1e86f1b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 27 deletions

View File

@ -2,7 +2,7 @@ import argparse
import json
from typing import AsyncGenerator
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn
@ -44,14 +44,8 @@ async def generate(request: Request) -> Response:
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
async def abort_request() -> None:
await engine.abort(request_id)
if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)
return StreamingResponse(stream_results())
# Non-streaming case
final_output = None

View File

@ -10,7 +10,7 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
import fastapi
import uvicorn
from fastapi import BackgroundTasks, Request
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
@ -229,9 +229,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(
index: int,
text: str,
@ -291,19 +288,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
# Streaming response
if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
media_type="text/event-stream")
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
@ -448,9 +441,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(
index: int,
text: str,
@ -510,19 +500,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
# Streaming response
if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
media_type="text/event-stream")
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res