Serving Benchmark Refactoring (#2433)
This commit is contained in:
parent
563836496a
commit
a4211a4dc3
@ -6,15 +6,16 @@ set -o pipefail
|
||||
# cd into parent directory of this file
|
||||
cd "$(dirname "${BASH_SOURCE[0]}")/.."
|
||||
|
||||
(wget && curl) || (apt-get update && apt-get install -y wget curl)
|
||||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||
|
||||
# run benchmarks and upload the result to buildkite
|
||||
# run python-based benchmarks and upload the result to buildkite
|
||||
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
|
||||
bench_latency_exit_code=$?
|
||||
|
||||
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
|
||||
bench_throughput_exit_code=$?
|
||||
|
||||
# run server-based benchmarks and upload the result to buildkite
|
||||
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
|
||||
server_pid=$!
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
@ -22,11 +23,14 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
|
||||
# wait for server to start, timeout after 600 seconds
|
||||
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
--backend openai \
|
||||
--dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--num-prompts 20 \
|
||||
--endpoint /v1/completions \
|
||||
--tokenizer meta-llama/Llama-2-7b-chat-hf 2>&1 | tee benchmark_serving.txt
|
||||
--tokenizer meta-llama/Llama-2-7b-chat-hf \
|
||||
--save-result \
|
||||
2>&1 | tee benchmark_serving.txt
|
||||
bench_serving_exit_code=$?
|
||||
kill $server_pid
|
||||
|
||||
@ -44,7 +48,7 @@ sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line
|
||||
echo "### Serving Benchmarks" >> benchmark_results.md
|
||||
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
|
||||
echo "" >> benchmark_results.md
|
||||
tail -n 5 benchmark_serving.txt >> benchmark_results.md # last 5 lines
|
||||
tail -n 13 benchmark_serving.txt >> benchmark_results.md # last 13 lines
|
||||
|
||||
# upload the results to buildkite
|
||||
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
|
||||
@ -61,3 +65,5 @@ fi
|
||||
if [ $bench_serving_exit_code -ne 0 ]; then
|
||||
exit $bench_serving_exit_code
|
||||
fi
|
||||
|
||||
/workspace/buildkite-agent artifact upload openai-*.json
|
||||
|
||||
284
benchmarks/backend_request_func.py
Normal file
284
benchmarks/backend_request_func.py
Normal file
@ -0,0 +1,284 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncInput:
|
||||
prompt: str
|
||||
api_url: str
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
model: str
|
||||
best_of: int = 1
|
||||
use_beam_search: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncOutput:
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0
|
||||
ttft: float = 0
|
||||
prompt_len: int = 0
|
||||
|
||||
|
||||
async def async_request_tgi(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
params = {
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
"do_sample": True,
|
||||
"temperature": 0.01, # TGI does not accept 0.0 temperature.
|
||||
"top_p": 0.99, # TGI does not accept 1.0 top_p.
|
||||
}
|
||||
payload = {
|
||||
"inputs": request_func_input.prompt,
|
||||
"parameters": params,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for data in response.content.iter_any():
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
output.latency = time.perf_counter() - st
|
||||
|
||||
body = data.decode("utf-8").lstrip("data:")
|
||||
output.generated_text = json.loads(body)["generated_text"]
|
||||
output.success = True
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.success = False
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_vllm(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"prompt": request_func_input.prompt,
|
||||
"n": 1,
|
||||
"best_of": request_func_input.best_of,
|
||||
"use_beam_search": request_func_input.use_beam_search,
|
||||
"temperature": 0.0 if request_func_input.use_beam_search else 1.0,
|
||||
"top_p": 1.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"ignore_eos": True,
|
||||
"stream": True,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for data in response.content.iter_any():
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
output.latency = time.perf_counter() - st
|
||||
|
||||
# When streaming, '\0' is appended to the end of the response.
|
||||
body = data.decode("utf-8").strip("\0")
|
||||
output.generated_text = json.loads(
|
||||
body)["text"][0][len(request_func_input.prompt):]
|
||||
output.success = True
|
||||
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.success = False
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_trt_llm(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
assert request_func_input.best_of == 1
|
||||
payload = {
|
||||
"accumulate_tokens": True,
|
||||
"text_input": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
ttft = 0
|
||||
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as resp:
|
||||
if resp.status == 200:
|
||||
async for data in resp.content.iter_any():
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
output.latency = time.perf_counter() - st
|
||||
|
||||
body = data.decode("utf-8").lstrip("data:")
|
||||
output.generated_text = json.loads(body)["text_output"]
|
||||
output.success = True
|
||||
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.success = False
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_deepspeed_mii(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert request_func_input.best_of == 1
|
||||
assert not request_func_input.use_beam_search
|
||||
|
||||
payload = {
|
||||
"prompts": request_func_input.prompt,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
"ignore_eos": True,
|
||||
"do_sample": True,
|
||||
"temperature":
|
||||
0.01, # deepspeed-mii does not accept 0.0 temperature.
|
||||
"top_p": 1.0,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
# DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use 0 as placeholder.
|
||||
# https://github.com/microsoft/DeepSpeed-MII/pull/311
|
||||
output.ttft = 0
|
||||
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=request_func_input.api_url,
|
||||
json=payload) as resp:
|
||||
if resp.status == 200:
|
||||
parsed_resp = await resp.json()
|
||||
output.latency = time.perf_counter() - st
|
||||
output.generated_text = parsed_resp[0]["generated_text"]
|
||||
output.success = True
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.success = False
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("v1/completions")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
chunk = chunk.decode("utf-8").lstrip("data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
body = json.loads(chunk)
|
||||
generated_text += body["choices"][0]["text"]
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.success = False
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
ASYNC_REQUEST_FUNCS = {
|
||||
"tgi": async_request_tgi,
|
||||
"vllm": async_request_vllm,
|
||||
"deepspeed-mii": async_request_deepspeed_mii,
|
||||
"openai": async_request_openai_completions,
|
||||
"tensorrt-llm": async_request_trt_llm,
|
||||
}
|
||||
@ -20,16 +20,36 @@ import asyncio
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Tuple
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
# (prompt len, output len, latency)
|
||||
REQUEST_LATENCY: List[Tuple[int, int, float]] = []
|
||||
from backend_request_func import (
|
||||
ASYNC_REQUEST_FUNCS,
|
||||
RequestFuncInput,
|
||||
RequestFuncOutput,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkMetrics:
|
||||
completed: int
|
||||
total_input: int
|
||||
total_output: int
|
||||
request_throughput: float
|
||||
input_throughput: float
|
||||
output_throughput: float
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
p99_ttft_ms: float
|
||||
mean_tpot_ms: float
|
||||
median_tpot_ms: float
|
||||
p99_tpot_ms: float
|
||||
|
||||
|
||||
def sample_requests(
|
||||
@ -46,6 +66,11 @@ def sample_requests(
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# some of these will be filtered out, so sample more than we need
|
||||
sampled_indices = random.sample(range(len(dataset)),
|
||||
int(num_requests * 1.2))
|
||||
dataset = [dataset[i] for i in sampled_indices]
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompts = [prompt for prompt, _ in dataset]
|
||||
prompt_token_ids = tokenizer(prompts).input_ids
|
||||
@ -92,80 +117,125 @@ async def get_request(
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
async def send_request(backend: str, model: str, api_url: str, prompt: str,
|
||||
prompt_len: int, output_len: int, best_of: int,
|
||||
use_beam_search: bool, pbar: tqdm) -> None:
|
||||
request_start_time = time.perf_counter()
|
||||
def calculate_metrics(
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
outputs: List[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> BenchmarkMetrics:
|
||||
total_output = 0
|
||||
total_input = 0
|
||||
completed = 0
|
||||
per_token_latencies = []
|
||||
ttfts = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
output_len = len(tokenizer.encode(outputs[i].generated_text))
|
||||
total_output += output_len
|
||||
total_input += input_requests[i][1]
|
||||
per_token_latencies.append(outputs[i].latency / output_len)
|
||||
ttfts.append(outputs[i].ttft)
|
||||
completed += 1
|
||||
|
||||
headers = {"User-Agent": "Benchmark Client"}
|
||||
if backend == "vllm":
|
||||
pload = {
|
||||
"prompt": prompt,
|
||||
"n": 1,
|
||||
"best_of": best_of,
|
||||
"use_beam_search": use_beam_search,
|
||||
"temperature": 0.0 if use_beam_search else 1.0,
|
||||
"top_p": 1.0,
|
||||
"max_tokens": output_len,
|
||||
"ignore_eos": True,
|
||||
"stream": False,
|
||||
}
|
||||
if model is not None:
|
||||
pload["model"] = model
|
||||
elif backend == "tgi":
|
||||
assert not use_beam_search
|
||||
params = {
|
||||
"best_of": best_of,
|
||||
"max_new_tokens": output_len,
|
||||
"do_sample": True,
|
||||
}
|
||||
pload = {
|
||||
"inputs": prompt,
|
||||
"parameters": params,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
metrics = BenchmarkMetrics(
|
||||
completed=completed,
|
||||
total_input=total_input,
|
||||
total_output=total_output,
|
||||
request_throughput=completed / dur_s,
|
||||
input_throughput=total_input / dur_s,
|
||||
output_throughput=total_output / dur_s,
|
||||
mean_ttft_ms=np.mean(ttfts) * 1000,
|
||||
median_ttft_ms=np.median(ttfts) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts, 99) * 1000,
|
||||
mean_tpot_ms=np.mean(per_token_latencies) * 1000,
|
||||
median_tpot_ms=np.median(per_token_latencies) * 1000,
|
||||
p99_tpot_ms=np.percentile(per_token_latencies, 99) * 1000,
|
||||
)
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=3 * 3600)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
while True:
|
||||
async with session.post(api_url, headers=headers,
|
||||
json=pload) as response:
|
||||
chunks = []
|
||||
async for chunk, _ in response.content.iter_chunks():
|
||||
chunks.append(chunk)
|
||||
output = b"".join(chunks).decode("utf-8")
|
||||
output = json.loads(output)
|
||||
|
||||
# Re-send the request if it failed.
|
||||
if "error" not in output:
|
||||
break
|
||||
|
||||
request_end_time = time.perf_counter()
|
||||
request_latency = request_end_time - request_start_time
|
||||
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
|
||||
pbar.update(1)
|
||||
return metrics
|
||||
|
||||
|
||||
async def benchmark(
|
||||
backend: str,
|
||||
model: str,
|
||||
api_url: str,
|
||||
model_id: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
best_of: int,
|
||||
use_beam_search: bool,
|
||||
request_rate: float,
|
||||
) -> None:
|
||||
tasks: List[asyncio.Task] = []
|
||||
pbar = tqdm(total=len(input_requests))
|
||||
disable_tqdm: bool,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS.get(backend)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||
|
||||
print(f"Traffic request rate: {request_rate}")
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks = []
|
||||
async for request in get_request(input_requests, request_rate):
|
||||
prompt, prompt_len, output_len = request
|
||||
task = asyncio.create_task(
|
||||
send_request(backend, model, api_url, prompt, prompt_len,
|
||||
output_len, best_of, use_beam_search, pbar))
|
||||
tasks.append(task)
|
||||
await asyncio.gather(*tasks)
|
||||
pbar.close()
|
||||
request_func_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)))
|
||||
outputs = await asyncio.gather(*tasks)
|
||||
|
||||
if not disable_tqdm:
|
||||
pbar.close()
|
||||
|
||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||
|
||||
metrics = calculate_metrics(
|
||||
input_requests=input_requests,
|
||||
outputs=outputs,
|
||||
dur_s=benchmark_duration,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
print(f"Successful requests: {metrics.completed}")
|
||||
print(f"Benchmark duration: {benchmark_duration:2f} s")
|
||||
print(f"Total input tokens: {metrics.total_input}")
|
||||
print(f"Total generated tokens: {metrics.total_output}")
|
||||
print(f"Request throughput: {metrics.request_throughput:.2f} requests/s")
|
||||
print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s")
|
||||
print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s")
|
||||
print(f"Mean TTFT: {metrics.mean_ttft_ms:.2f} ms")
|
||||
print(f"Median TTFT: {metrics.median_ttft_ms:.2f} ms")
|
||||
print(f"P99 TTFT: {metrics.p99_ttft_ms:.2f} ms")
|
||||
print(f"Mean TPOT: {metrics.mean_tpot_ms:.2f} ms")
|
||||
print(f"Median TPOT: {metrics.median_tpot_ms:.2f} ms")
|
||||
print(f"P99 TPOT: {metrics.p99_tpot_ms:.2f} ms")
|
||||
|
||||
result = {
|
||||
"duration": benchmark_duration,
|
||||
"completed": metrics.completed,
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_inthroughput": metrics.request_throughput,
|
||||
"input_throughput": metrics.input_throughput,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"mean_ttft_ms": metrics.mean_ttft_ms,
|
||||
"median_ttft_ms": metrics.median_ttft_ms,
|
||||
"p99_ttft_ms": metrics.p99_ttft_ms,
|
||||
"mean_tpot_ms": metrics.mean_tpot_ms,
|
||||
"median_tpot_ms": metrics.median_tpot_ms,
|
||||
"p99_tpot_ms": metrics.p99_tpot_ms
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
@ -173,77 +243,145 @@ def main(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
|
||||
tokenizer = get_tokenizer(args.tokenizer,
|
||||
backend = args.backend
|
||||
model_id = args.model
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
|
||||
if args.base_url is not None:
|
||||
api_url = f"{args.base_url}{args.endpoint}"
|
||||
else:
|
||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_id,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
asyncio.run(
|
||||
benchmark(args.backend, args.model, api_url, input_requests,
|
||||
args.best_of, args.use_beam_search, args.request_rate))
|
||||
benchmark_end_time = time.perf_counter()
|
||||
benchmark_time = benchmark_end_time - benchmark_start_time
|
||||
print(f"Total time: {benchmark_time:.2f} s")
|
||||
print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
|
||||
benchmark_result = asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
best_of=args.best_of,
|
||||
use_beam_search=args.use_beam_search,
|
||||
request_rate=args.request_rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
))
|
||||
|
||||
# Compute the latency statistics.
|
||||
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
|
||||
print(f"Average latency: {avg_latency:.2f} s")
|
||||
avg_per_token_latency = np.mean([
|
||||
latency / (prompt_len + output_len)
|
||||
for prompt_len, output_len, latency in REQUEST_LATENCY
|
||||
])
|
||||
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
|
||||
avg_per_output_token_latency = np.mean(
|
||||
[latency / output_len for _, output_len, latency in REQUEST_LATENCY])
|
||||
print("Average latency per output token: "
|
||||
f"{avg_per_output_token_latency:.2f} s")
|
||||
# Save config and results to json
|
||||
if args.save_result:
|
||||
result_json = {}
|
||||
|
||||
# Setup
|
||||
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
result_json["date"] = current_dt
|
||||
result_json["backend"] = backend
|
||||
result_json["version"] = args.version
|
||||
result_json["model_id"] = model_id
|
||||
result_json["tokenizer_id"] = tokenizer_id
|
||||
result_json["best_of"] = args.best_of
|
||||
result_json["use_beam_search"] = args.use_beam_search
|
||||
result_json["num_prompts"] = args.num_prompts
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (
|
||||
args.request_rate if args.request_rate < float("inf") else "inf")
|
||||
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
# Save to file
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
|
||||
with open(file_name, "w") as outfile:
|
||||
json.dump(result_json, outfile)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark the online serving throughput.")
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
default="vllm",
|
||||
choices=["vllm", "tgi"])
|
||||
parser.add_argument("--protocol",
|
||||
type=str,
|
||||
default="http",
|
||||
choices=["http", "https"])
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="vllm",
|
||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="N/A",
|
||||
help="Version of the serving backend/engine.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Server or API base url if not using http host and port.",
|
||||
)
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--endpoint", type=str, default="/generate")
|
||||
parser.add_argument("--model", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--endpoint",
|
||||
type=str,
|
||||
default="/generate",
|
||||
help="API endpoint.",
|
||||
)
|
||||
parser.add_argument("--dataset",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument("--tokenizer",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name or path of the tokenizer.")
|
||||
parser.add_argument("--best-of",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Generates `best_of` sequences per prompt and "
|
||||
"returns the best one.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer",
|
||||
type=str,
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default model tokenizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best-of",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Generates `best_of` sequences per prompt and "
|
||||
"returns the best one.",
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument("--request-rate",
|
||||
type=float,
|
||||
default=float("inf"),
|
||||
help="Number of requests per second. If this is inf, "
|
||||
"then all the requests are sent at time 0. "
|
||||
"Otherwise, we use Poisson process to synthesize "
|
||||
"the request arrival times.")
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of prompts to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-rate",
|
||||
type=float,
|
||||
default=float("inf"),
|
||||
help="Number of requests per second. If this is inf, "
|
||||
"then all the requests are sent at time 0. "
|
||||
"Otherwise, we use Poisson process to synthesize "
|
||||
"the request arrival times.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Trust remote code from huggingface",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-tqdm",
|
||||
action="store_true",
|
||||
help="Specify to disbale tqdm progress bar.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-result",
|
||||
action="store_true",
|
||||
help="Specify to save benchmark results to a json file",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@ -6,7 +6,7 @@ TOKENS=$2
|
||||
|
||||
docker run --gpus all --shm-size 1g -p $PORT:80 \
|
||||
-v $PWD/data:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:0.8 \
|
||||
ghcr.io/huggingface/text-generation-inference:1.4.0 \
|
||||
--model-id $MODEL \
|
||||
--sharded false \
|
||||
--max-input-length 1024 \
|
||||
|
||||
Loading…
Reference in New Issue
Block a user