Support OpenAI API server in benchmark_serving.py (#2172)

This commit is contained in:
Harry Mellor 2024-01-19 04:34:08 +00:00 committed by GitHub
parent dd7e8f5f64
commit 2709c0009a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 32 deletions

3
.gitignore vendored
View File

@ -181,3 +181,6 @@ _build/
# hip files generated by PyTorch # hip files generated by PyTorch
*.hip *.hip
*_hip* *_hip*
# Benchmark dataset
*.json

View File

@ -24,6 +24,7 @@ from typing import AsyncGenerator, List, Tuple
import aiohttp import aiohttp
import numpy as np import numpy as np
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
@ -40,15 +41,10 @@ def sample_requests(
with open(dataset_path) as f: with open(dataset_path) as f:
dataset = json.load(f) dataset = json.load(f)
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [ dataset = [data for data in dataset if len(data["conversations"]) >= 2]
data for data in dataset
if len(data["conversations"]) >= 2
]
# Only keep the first two turns of each conversation. # Only keep the first two turns of each conversation.
dataset = [ dataset = [(data["conversations"][0]["value"],
(data["conversations"][0]["value"], data["conversations"][1]["value"]) data["conversations"][1]["value"]) for data in dataset]
for data in dataset
]
# Tokenize the prompts and completions. # Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset] prompts = [prompt for prompt, _ in dataset]
@ -98,6 +94,7 @@ async def get_request(
async def send_request( async def send_request(
backend: str, backend: str,
model: str,
api_url: str, api_url: str,
prompt: str, prompt: str,
prompt_len: int, prompt_len: int,
@ -120,6 +117,8 @@ async def send_request(
"ignore_eos": True, "ignore_eos": True,
"stream": False, "stream": False,
} }
if model is not None:
pload["model"] = model
elif backend == "tgi": elif backend == "tgi":
assert not use_beam_search assert not use_beam_search
params = { params = {
@ -137,7 +136,8 @@ async def send_request(
timeout = aiohttp.ClientTimeout(total=3 * 3600) timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout) as session: async with aiohttp.ClientSession(timeout=timeout) as session:
while True: while True:
async with session.post(api_url, headers=headers, json=pload) as response: async with session.post(api_url, headers=headers,
json=pload) as response:
chunks = [] chunks = []
async for chunk, _ in response.content.iter_chunks(): async for chunk, _ in response.content.iter_chunks():
chunks.append(chunk) chunks.append(chunk)
@ -155,6 +155,7 @@ async def send_request(
async def benchmark( async def benchmark(
backend: str, backend: str,
model: str,
api_url: str, api_url: str,
input_requests: List[Tuple[str, int, int]], input_requests: List[Tuple[str, int, int]],
best_of: int, best_of: int,
@ -164,11 +165,11 @@ async def benchmark(
tasks: List[asyncio.Task] = [] tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate): async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request prompt, prompt_len, output_len = request
task = asyncio.create_task(send_request(backend, api_url, prompt, task = asyncio.create_task(
prompt_len, output_len, send_request(backend, model, api_url, prompt, prompt_len,
best_of, use_beam_search)) output_len, best_of, use_beam_search))
tasks.append(task) tasks.append(task)
await asyncio.gather(*tasks) await tqdm.gather(*tasks)
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
@ -176,13 +177,15 @@ def main(args: argparse.Namespace):
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
api_url = f"http://{args.host}:{args.port}/generate" api_url = f"http://{args.host}:{args.port}{args.endpoint}"
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code) tokenizer = get_tokenizer(args.tokenizer,
trust_remote_code=args.trust_remote_code)
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of, asyncio.run(
args.use_beam_search, args.request_rate)) benchmark(args.backend, args.model, api_url, input_requests,
args.best_of, args.use_beam_search, args.request_rate))
benchmark_end_time = time.perf_counter() benchmark_end_time = time.perf_counter()
benchmark_time = benchmark_end_time - benchmark_start_time benchmark_time = benchmark_end_time - benchmark_start_time
print(f"Total time: {benchmark_time:.2f} s") print(f"Total time: {benchmark_time:.2f} s")
@ -196,10 +199,8 @@ def main(args: argparse.Namespace):
for prompt_len, output_len, latency in REQUEST_LATENCY for prompt_len, output_len, latency in REQUEST_LATENCY
]) ])
print(f"Average latency per token: {avg_per_token_latency:.2f} s") print(f"Average latency per token: {avg_per_token_latency:.2f} s")
avg_per_output_token_latency = np.mean([ avg_per_output_token_latency = np.mean(
latency / output_len [latency / output_len for _, output_len, latency in REQUEST_LATENCY])
for _, output_len, latency in REQUEST_LATENCY
])
print("Average latency per output token: " print("Average latency per output token: "
f"{avg_per_output_token_latency:.2f} s") f"{avg_per_output_token_latency:.2f} s")
@ -207,27 +208,42 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Benchmark the online serving throughput.") description="Benchmark the online serving throughput.")
parser.add_argument("--backend", type=str, default="vllm", parser.add_argument("--backend",
type=str,
default="vllm",
choices=["vllm", "tgi"]) choices=["vllm", "tgi"])
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--dataset", type=str, required=True, parser.add_argument("--endpoint", type=str, default="/generate")
parser.add_argument("--model", type=str, default=None)
parser.add_argument("--dataset",
type=str,
required=True,
help="Path to the dataset.") help="Path to the dataset.")
parser.add_argument("--tokenizer", type=str, required=True, parser.add_argument("--tokenizer",
type=str,
required=True,
help="Name or path of the tokenizer.") help="Name or path of the tokenizer.")
parser.add_argument("--best-of", type=int, default=1, parser.add_argument("--best-of",
type=int,
default=1,
help="Generates `best_of` sequences per prompt and " help="Generates `best_of` sequences per prompt and "
"returns the best one.") "returns the best one.")
parser.add_argument("--use-beam-search", action="store_true") parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000, parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.") help="Number of prompts to process.")
parser.add_argument("--request-rate", type=float, default=float("inf"), parser.add_argument("--request-rate",
type=float,
default=float("inf"),
help="Number of requests per second. If this is inf, " help="Number of requests per second. If this is inf, "
"then all the requests are sent at time 0. " "then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize " "Otherwise, we use Poisson process to synthesize "
"the request arrival times.") "the request arrival times.")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--trust-remote-code', action='store_true', parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface') help='trust remote code from huggingface')
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)