From 2d1e86f1b15396119321cfb3a77acde72b0c08ee Mon Sep 17 00:00:00 2001
From: Roy
Date: Fri, 22 Sep 2023 04:25:05 +0800
Subject: [PATCH] clean api code, remove redundant background task. (#1102)

---
 vllm/entrypoints/api_server.py        | 10 ++--------
 vllm/entrypoints/openai/api_server.py | 24 +++++-------------------
 2 files changed, 7 insertions(+), 27 deletions(-)

diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index e9d3342a..5e63a02c 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -2,7 +2,7 @@ import argparse
 import json
 from typing import AsyncGenerator
 
-from fastapi import BackgroundTasks, FastAPI, Request
+from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 import uvicorn
 
@@ -44,14 +44,8 @@ async def generate(request: Request) -> Response:
             ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
-    async def abort_request() -> None:
-        await engine.abort(request_id)
-
     if stream:
-        background_tasks = BackgroundTasks()
-        # Abort the request if the client disconnects.
-        background_tasks.add_task(abort_request)
-        return StreamingResponse(stream_results(), background=background_tasks)
+        return StreamingResponse(stream_results())
 
     # Non-streaming case
     final_output = None
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index bc827b34..eed3bebb 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -10,7 +10,7 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
 
 import fastapi
 import uvicorn
-from fastapi import BackgroundTasks, Request
+from fastapi import Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -229,9 +229,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
     result_generator = engine.generate(prompt, sampling_params, request_id,
                                        token_ids)
 
-    async def abort_request() -> None:
-        await engine.abort(request_id)
-
     def create_stream_response_json(
         index: int,
         text: str,
@@ -291,19 +288,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
 
     # Streaming response
     if request.stream:
-        background_tasks = BackgroundTasks()
-        # Abort the request if the client disconnects.
-        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream",
-                                 background=background_tasks)
+                                 media_type="text/event-stream")
 
     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
         if await raw_request.is_disconnected():
             # Abort the request if the client disconnects.
-            await abort_request()
+            await engine.abort(request_id)
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "Client disconnected")
         final_res = res
@@ -448,9 +441,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
              and (request.best_of is None or request.n == request.best_of)
              and not request.use_beam_search)
 
-    async def abort_request() -> None:
-        await engine.abort(request_id)
-
    def create_stream_response_json(
        index: int,
        text: str,
@@ -510,19 +500,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
     # Streaming response
     if stream:
-        background_tasks = BackgroundTasks()
-        # Abort the request if the client disconnects.
-        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream",
-                                 background=background_tasks)
+                                 media_type="text/event-stream")
 
     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
         if await raw_request.is_disconnected():
             # Abort the request if the client disconnects.
-            await abort_request()
+            await engine.abort(request_id)
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "Client disconnected")
         final_res = res
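The background task can go because the abort-on-disconnect path is already covered one layer down: when a streaming client disconnects, Starlette cancels the task driving the response generator, and the engine's generate() wrapper is expected to abort the request when its generator is cancelled, so scheduling a separate BackgroundTasks abort after the response is redundant. The sketch below is a minimal illustration of that pattern, not vLLM's actual code; ToyEngine, its dummy token loop, and the /generate route are hypothetical stand-ins.

```python
import asyncio
import json
from typing import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


class ToyEngine:
    """Hypothetical stand-in for an async inference engine."""

    async def generate(self, request_id: str) -> AsyncGenerator[str, None]:
        try:
            for i in range(100):
                await asyncio.sleep(0.1)  # pretend to run one decode step
                yield f"token-{i}"
        except asyncio.CancelledError:
            # The client disconnected (Starlette cancelled the response
            # task), so clean up the in-flight request right here instead
            # of in a separately scheduled background task.
            await self.abort(request_id)
            raise

    async def abort(self, request_id: str) -> None:
        print(f"aborted request {request_id}")


engine = ToyEngine()


@app.get("/generate")
async def generate() -> StreamingResponse:
    request_id = "demo-request"

    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for token in engine.generate(request_id):
            yield (json.dumps({"text": token}) + "\0").encode("utf-8")

    # No `background=` argument: cancellation already triggers the abort.
    return StreamingResponse(stream_results())
```

With cleanup owned by the engine, every endpoint that streams from it gets abort-on-disconnect for free, which is why the patch can delete the per-endpoint abort_request helpers and BackgroundTasks wiring in both servers.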