[Frontend] Add progress reporting to run_batch.py (#8060)

Co-authored-by: Adam Lugowski <adam.lugowski@parasail.io>
This commit is contained in:
Adam Lugowski 2024-09-09 11:16:37 -07:00 committed by GitHub
parent 08287ef675
commit 58fcc8545a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,9 +1,11 @@
import asyncio
from io import StringIO
from typing import Awaitable, Callable, List
from typing import Awaitable, Callable, List, Optional
import aiohttp
import torch
from prometheus_client import start_http_server
from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
@ -78,6 +80,38 @@ def parse_args():
return parser.parse_args()
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
class BatchProgressTracker:
def __init__(self):
self._total = 0
self._pbar: Optional[tqdm] = None
def submitted(self):
self._total += 1
def completed(self):
if self._pbar:
self._pbar.update()
def pbar(self) -> tqdm:
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
self._pbar = tqdm(total=self._total,
unit="req",
desc="Running batch",
mininterval=5,
disable=not enable_tqdm,
bar_format=_BAR_FORMAT)
return self._pbar
async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \
@ -102,7 +136,8 @@ async def write_file(path_or_url: str, data: str) -> None:
async def run_request(serving_engine_func: Callable,
request: BatchRequestInput) -> BatchRequestOutput:
request: BatchRequestInput,
tracker: BatchProgressTracker) -> BatchRequestOutput:
response = await serving_engine_func(request.body)
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
@ -125,6 +160,7 @@ async def run_request(serving_engine_func: Callable,
else:
raise ValueError("Request must not be sent in stream mode")
tracker.completed()
return batch_output
@ -164,6 +200,9 @@ async def main(args):
request_logger=request_logger,
)
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
# Submit all requests in the file to the engine "concurrently".
response_futures: List[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
@ -178,16 +217,19 @@ async def main(args):
if request.url == "/v1/chat/completions":
response_futures.append(
run_request(openai_serving_chat.create_chat_completion,
request))
request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings":
response_futures.append(
run_request(openai_serving_embedding.create_embedding,
request))
run_request(openai_serving_embedding.create_embedding, request,
tracker))
tracker.submitted()
else:
raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
"supported in the batch endpoint.")
responses = await asyncio.gather(*response_futures)
with tracker.pbar():
responses = await asyncio.gather(*response_futures)
output_buffer = StringIO()
for response in responses: