clean api code, remove redundant background task. (#1102)
This commit is contained in:
parent
1ac4ccf73c
commit
2d1e86f1b1
@@ -2,7 +2,7 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from fastapi import BackgroundTasks, FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
@@ -44,14 +44,8 @@ async def generate(request: Request) -> Response:
|
|||||||
ret = {"text": text_outputs}
|
ret = {"text": text_outputs}
|
||||||
yield (json.dumps(ret) + "\0").encode("utf-8")
|
yield (json.dumps(ret) + "\0").encode("utf-8")
|
||||||
|
|
||||||
async def abort_request() -> None:
|
|
||||||
await engine.abort(request_id)
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
background_tasks = BackgroundTasks()
|
return StreamingResponse(stream_results())
|
||||||
# Abort the request if the client disconnects.
|
|
||||||
background_tasks.add_task(abort_request)
|
|
||||||
return StreamingResponse(stream_results(), background=background_tasks)
|
|
||||||
|
|
||||||
# Non-streaming case
|
# Non-streaming case
|
||||||
final_output = None
|
final_output = None
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import BackgroundTasks, Request
|
from fastapi import Request
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
@@ -229,9 +229,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||||||
result_generator = engine.generate(prompt, sampling_params, request_id,
|
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||||
token_ids)
|
token_ids)
|
||||||
|
|
||||||
async def abort_request() -> None:
|
|
||||||
await engine.abort(request_id)
|
|
||||||
|
|
||||||
def create_stream_response_json(
|
def create_stream_response_json(
|
||||||
index: int,
|
index: int,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -291,19 +288,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||||||
|
|
||||||
# Streaming response
|
# Streaming response
|
||||||
if request.stream:
|
if request.stream:
|
||||||
background_tasks = BackgroundTasks()
|
|
||||||
# Abort the request if the client disconnects.
|
|
||||||
background_tasks.add_task(abort_request)
|
|
||||||
return StreamingResponse(completion_stream_generator(),
|
return StreamingResponse(completion_stream_generator(),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream")
|
||||||
background=background_tasks)
|
|
||||||
|
|
||||||
# Non-streaming response
|
# Non-streaming response
|
||||||
final_res: RequestOutput = None
|
final_res: RequestOutput = None
|
||||||
async for res in result_generator:
|
async for res in result_generator:
|
||||||
if await raw_request.is_disconnected():
|
if await raw_request.is_disconnected():
|
||||||
# Abort the request if the client disconnects.
|
# Abort the request if the client disconnects.
|
||||||
await abort_request()
|
await engine.abort(request_id)
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"Client disconnected")
|
"Client disconnected")
|
||||||
final_res = res
|
final_res = res
|
||||||
@@ -448,9 +441,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||||||
and (request.best_of is None or request.n == request.best_of)
|
and (request.best_of is None or request.n == request.best_of)
|
||||||
and not request.use_beam_search)
|
and not request.use_beam_search)
|
||||||
|
|
||||||
async def abort_request() -> None:
|
|
||||||
await engine.abort(request_id)
|
|
||||||
|
|
||||||
def create_stream_response_json(
|
def create_stream_response_json(
|
||||||
index: int,
|
index: int,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -510,19 +500,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||||||
|
|
||||||
# Streaming response
|
# Streaming response
|
||||||
if stream:
|
if stream:
|
||||||
background_tasks = BackgroundTasks()
|
|
||||||
# Abort the request if the client disconnects.
|
|
||||||
background_tasks.add_task(abort_request)
|
|
||||||
return StreamingResponse(completion_stream_generator(),
|
return StreamingResponse(completion_stream_generator(),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream")
|
||||||
background=background_tasks)
|
|
||||||
|
|
||||||
# Non-streaming response
|
# Non-streaming response
|
||||||
final_res: RequestOutput = None
|
final_res: RequestOutput = None
|
||||||
async for res in result_generator:
|
async for res in result_generator:
|
||||||
if await raw_request.is_disconnected():
|
if await raw_request.is_disconnected():
|
||||||
# Abort the request if the client disconnects.
|
# Abort the request if the client disconnects.
|
||||||
await abort_request()
|
await engine.abort(request_id)
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"Client disconnected")
|
"Client disconnected")
|
||||||
final_res = res
|
final_res = res
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user