From 63e835cbccec62cc34ed27a6133d9a7f4af4a068 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 22 Jan 2024 22:40:31 +0000 Subject: [PATCH] Fix progress bar and allow HTTPS in `benchmark_serving.py` (#2552) --- benchmarks/benchmark_serving.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 6b5dd097..28faa96f 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -101,6 +101,7 @@ async def send_request( output_len: int, best_of: int, use_beam_search: bool, + pbar: tqdm ) -> None: request_start_time = time.perf_counter() @@ -151,6 +152,8 @@ async def send_request( request_end_time = time.perf_counter() request_latency = request_end_time - request_start_time REQUEST_LATENCY.append((prompt_len, output_len, request_latency)) + pbar.update(1) + async def benchmark( @@ -163,13 +166,15 @@ async def benchmark( request_rate: float, ) -> None: tasks: List[asyncio.Task] = [] + pbar = tqdm(total=len(input_requests)) async for request in get_request(input_requests, request_rate): prompt, prompt_len, output_len = request task = asyncio.create_task( send_request(backend, model, api_url, prompt, prompt_len, - output_len, best_of, use_beam_search)) + output_len, best_of, use_beam_search, pbar)) tasks.append(task) - await tqdm.gather(*tasks) + await asyncio.gather(*tasks) + pbar.close() def main(args: argparse.Namespace): @@ -177,7 +182,7 @@ def main(args: argparse.Namespace): random.seed(args.seed) np.random.seed(args.seed) - api_url = f"http://{args.host}:{args.port}{args.endpoint}" + api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}" tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code) input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) @@ -212,6 +217,7 @@ if __name__ == "__main__": type=str, default="vllm", choices=["vllm", "tgi"]) + parser.add_argument("--protocol", type=str, default="http", choices=["http", "https"]) parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--endpoint", type=str, default="/generate")