[Benchmark] Add --async-engine option to benchmark_throughput.py (#7964)

Author: Nick Hill
Date:   2024-09-03 17:57:41 -07:00 (committed by GitHub)
Commit: d4db9f53c8 (parent 2188a60c7e)
Signature: no known key found for this signature in database (GPG Key ID: B5690EEEBB952194)

3 changed files with 143 additions and 19 deletions

Changed file: benchmarks/benchmark_throughput.py

@@ -6,13 +6,16 @@ import time
 from typing import List, Optional, Tuple
 
 import torch
+import uvloop
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
 
-from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.entrypoints.openai.api_server import (
+    build_async_engine_client_from_engine_args)
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, merge_async_iterators
 
 
 def sample_requests(
@@ -135,6 +138,93 @@ def run_vllm(
     return end - start
 
 
+async def run_vllm_async(
+    requests: List[Tuple[str, int, int]],
+    model: str,
+    tokenizer: str,
+    quantization: Optional[str],
+    tensor_parallel_size: int,
+    seed: int,
+    n: int,
+    use_beam_search: bool,
+    trust_remote_code: bool,
+    dtype: str,
+    max_model_len: Optional[int],
+    enforce_eager: bool,
+    kv_cache_dtype: str,
+    quantization_param_path: Optional[str],
+    device: str,
+    enable_prefix_caching: bool,
+    enable_chunked_prefill: bool,
+    max_num_batched_tokens: int,
+    distributed_executor_backend: Optional[str],
+    gpu_memory_utilization: float = 0.9,
+    num_scheduler_steps: int = 1,
+    use_v2_block_manager: bool = False,
+    download_dir: Optional[str] = None,
+    load_format: str = EngineArgs.load_format,
+    disable_async_output_proc: bool = False,
+    disable_frontend_multiprocessing: bool = False,
+) -> float:
+    from vllm import SamplingParams
+    engine_args = AsyncEngineArgs(
+        model=model,
+        tokenizer=tokenizer,
+        quantization=quantization,
+        tensor_parallel_size=tensor_parallel_size,
+        seed=seed,
+        trust_remote_code=trust_remote_code,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=gpu_memory_utilization,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        quantization_param_path=quantization_param_path,
+        device=device,
+        enable_prefix_caching=enable_prefix_caching,
+        download_dir=download_dir,
+        enable_chunked_prefill=enable_chunked_prefill,
+        max_num_batched_tokens=max_num_batched_tokens,
+        distributed_executor_backend=distributed_executor_backend,
+        load_format=load_format,
+        num_scheduler_steps=num_scheduler_steps,
+        use_v2_block_manager=use_v2_block_manager,
+        disable_async_output_proc=disable_async_output_proc,
+        worker_use_ray=False,
+        engine_use_ray=False,
+        disable_log_requests=True,
+    )
+
+    async with build_async_engine_client_from_engine_args(
+            engine_args, disable_frontend_multiprocessing) as llm:
+
+        # Add the requests to the engine.
+        prompts: List[str] = []
+        sampling_params: List[SamplingParams] = []
+        for prompt, _, output_len in requests:
+            prompts.append(prompt)
+            sampling_params.append(
+                SamplingParams(
+                    n=n,
+                    temperature=0.0 if use_beam_search else 1.0,
+                    top_p=1.0,
+                    use_beam_search=use_beam_search,
+                    ignore_eos=True,
+                    max_tokens=output_len,
+                ))
+
+        generators = []
+        start = time.perf_counter()
+        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+            generator = llm.generate(prompt, sp, request_id=f"test{i}")
+            generators.append(generator)
+        all_gens = merge_async_iterators(*generators)
+        async for i, res in all_gens:
+            pass
+        end = time.perf_counter()
+        return end - start
+
+
 def run_hf(
     requests: List[Tuple[str, int, int]],
     model: str,
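For context, the timing loop in the new `run_vllm_async` submits every request up front and then drains all result streams through a single `async for` over `merge_async_iterators`. Below is a minimal standalone sketch of that submit-then-drain pattern; `fake_generate` and `merge` are simplified stand-ins written for illustration, not vLLM APIs.

```python
# Standalone sketch of the submit-all-then-drain pattern used by run_vllm_async.
# "fake_generate" and "merge" are simplified stand-ins, not vLLM APIs.
import asyncio
import random
import time
from typing import AsyncIterator, Tuple


async def fake_generate(request_id: str) -> AsyncIterator[str]:
    # Stand-in for llm.generate(): stream a few chunks per request.
    for step in range(3):
        await asyncio.sleep(random.uniform(0.01, 0.05))
        yield f"{request_id}-chunk{step}"


async def merge(*gens: AsyncIterator[str]) -> AsyncIterator[Tuple[int, str]]:
    # Simplified analogue of vllm.utils.merge_async_iterators: pull from all
    # generators concurrently and yield (index, item) pairs as items arrive.
    queue: asyncio.Queue = asyncio.Queue()
    finished = object()  # per-generator completion sentinel

    async def pump(i: int, gen: AsyncIterator[str]) -> None:
        async for item in gen:
            await queue.put((i, item))
        await queue.put((i, finished))

    tasks = [asyncio.create_task(pump(i, g)) for i, g in enumerate(gens)]
    remaining = len(tasks)
    while remaining:
        i, item = await queue.get()
        if item is finished:
            remaining -= 1
        else:
            yield i, item


async def main() -> None:
    generators = [fake_generate(f"test{i}") for i in range(4)]
    start = time.perf_counter()
    async for i, chunk in merge(*generators):
        pass  # the benchmark only cares about total wall time
    print(f"elapsed: {time.perf_counter() - start:.3f}s")


asyncio.run(main())
```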
@@ -230,7 +320,7 @@ def main(args: argparse.Namespace):
                                    args.output_len)
 
     if args.backend == "vllm":
-        elapsed_time = run_vllm(
+        run_args = [
             requests, args.model, args.tokenizer, args.quantization,
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
@@ -240,7 +330,14 @@ def main(args: argparse.Namespace):
             args.max_num_batched_tokens, args.distributed_executor_backend,
             args.gpu_memory_utilization, args.num_scheduler_steps,
             args.use_v2_block_manager, args.download_dir, args.load_format,
-            args.disable_async_output_proc)
+            args.disable_async_output_proc
+        ]
+
+        if args.async_engine:
+            run_args.append(args.disable_frontend_multiprocessing)
+            elapsed_time = uvloop.run(run_vllm_async(*run_args))
+        else:
+            elapsed_time = run_vllm(*run_args)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
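The dispatch above builds one shared positional argument list, appends the async-only `disable_frontend_multiprocessing` value, and runs the coroutine with `uvloop.run` (a drop-in for `asyncio.run` on a uvloop event loop). A condensed sketch of the same pattern follows; `sync_runner` and `async_runner` are placeholder names, not functions from this diff.

```python
# Condensed sketch of the sync/async dispatch pattern above.
import asyncio

import uvloop


def sync_runner(label: str, n: int) -> float:
    return float(n)


async def async_runner(label: str, n: int, extra_flag: bool) -> float:
    await asyncio.sleep(0)
    return float(n)


use_async = True
run_args = ["demo", 3]          # shared positional arguments
if use_async:
    run_args.append(False)      # async-only argument appended last
    result = uvloop.run(async_runner(*run_args))  # asyncio.run on a uvloop loop
else:
    result = sync_runner(*run_args)
print(result)
```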
@@ -426,6 +523,14 @@ if __name__ == "__main__":
                         action='store_true',
                         default=False,
                         help="Disable async output processor for vLLM backend.")
+    parser.add_argument("--async-engine",
+                        action='store_true',
+                        default=False,
+                        help="Use vLLM async engine rather than LLM class.")
+    parser.add_argument("--disable-frontend-multiprocessing",
+                        action='store_true',
+                        default=False,
+                        help="Disable decoupled async engine frontend.")
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
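Both new options are plain `store_true` flags; a small sketch of how they parse, with stock `argparse` standing in for vLLM's `FlexibleArgumentParser`:

```python
# Minimal sketch of how the two new flags parse; plain argparse stands in
# for vLLM's FlexibleArgumentParser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--async-engine",
                    action="store_true",
                    default=False,
                    help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--disable-frontend-multiprocessing",
                    action="store_true",
                    default=False,
                    help="Disable decoupled async engine frontend.")

args = parser.parse_args(["--async-engine"])
print(args.async_engine)                      # True
print(args.disable_frontend_multiprocessing)  # False; only consulted when
                                              # --async-engine is set
```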

Changed file: vllm/entrypoints/openai/api_server.py

@@ -67,7 +67,7 @@ _running_tasks: Set[asyncio.Task] = set()
 
 
 def model_is_embedding(model_name: str, trust_remote_code: bool,
-                       quantization: str) -> bool:
+                       quantization: Optional[str]) -> bool:
     return ModelConfig(model=model_name,
                        tokenizer=model_name,
                        tokenizer_mode="auto",
@@ -96,13 +96,6 @@ async def lifespan(app: FastAPI):
 @asynccontextmanager
 async def build_async_engine_client(
         args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
-    """
-    Create AsyncEngineClient, either:
-        - in-process using the AsyncLLMEngine Directly
-        - multiprocess using AsyncLLMEngine RPC
-
-    Returns the Client or None if the creation failed.
-    """
 
     # Context manager to handle async_engine_client lifecycle
     # Ensures everything is shutdown and cleaned up on error/exit
@@ -112,14 +105,37 @@ async def build_async_engine_client(
     # Backend itself still global for the silly lil' health handler
     global async_engine_client
 
+    async with build_async_engine_client_from_engine_args(
+            engine_args, args.disable_frontend_multiprocessing) as engine:
+        async_engine_client = engine  # type: ignore[assignment]
+        yield engine
+
+
+@asynccontextmanager
+async def build_async_engine_client_from_engine_args(
+    engine_args: AsyncEngineArgs,
+    disable_frontend_multiprocessing: bool = False,
+) -> AsyncIterator[Optional[AsyncEngineClient]]:
+    """
+    Create AsyncEngineClient, either:
+        - in-process using the AsyncLLMEngine Directly
+        - multiprocess using AsyncLLMEngine RPC
+
+    Returns the Client or None if the creation failed.
+    """
+
     # If manually triggered or embedding model, use AsyncLLMEngine in process.
     # TODO: support embedding model via RPC.
-    if (model_is_embedding(args.model, args.trust_remote_code,
-                           args.quantization)
-            or args.disable_frontend_multiprocessing):
-        async_engine_client = AsyncLLMEngine.from_engine_args(
+    if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
+                           engine_args.quantization)
+            or disable_frontend_multiprocessing):
+        engine_client = AsyncLLMEngine.from_engine_args(
             engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-        yield async_engine_client
+        try:
+            yield engine_client
+        finally:
+            engine_client.shutdown_background_loop()
         return
 
     # Otherwise, use the multiprocessing AsyncLLMEngine.
@@ -148,7 +164,6 @@ async def build_async_engine_client(
     # NOTE: Actually, this is not true yet. We still need to support
     # embedding models via RPC (see TODO above)
     rpc_client = AsyncEngineRPCClient(rpc_path)
-    async_engine_client = rpc_client  # type: ignore
 
     # Start RPCServer in separate process (holds the AsyncLLMEngine).
     context = multiprocessing.get_context("spawn")
@@ -174,7 +189,7 @@ async def build_async_engine_client(
                     yield None
                     return
 
-        yield async_engine_client
+        yield rpc_client  # type: ignore[misc]
     finally:
         # Ensure rpc server process was terminated
         rpc_server_process.terminate()
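Extracting `build_async_engine_client_from_engine_args` lets code outside the API server reuse the engine-client lifecycle, which is exactly what the benchmark now does. A hedged sketch of calling it directly (the model name is illustrative; the `generate()` usage mirrors the benchmark's):

```python
# Sketch of calling the refactored helper directly, as the benchmark does.
# The model name is illustrative, not from the diff.
import uvloop

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)


async def main() -> None:
    engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                                  disable_log_requests=True)

    # True keeps the AsyncLLMEngine in-process; False uses the multiprocessing
    # RPC frontend. Either way the client is cleaned up when the context exits.
    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing=True) as client:
        sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
        final = None
        async for output in client.generate("Hello, my name is",
                                            sampling_params,
                                            request_id="demo-0"):
            final = output  # keep only the last streamed RequestOutput
        if final is not None:
            print(final.outputs[0].text)


uvloop.run(main())
```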

Changed file: vllm/entrypoints/openai/rpc/client.py

@@ -7,6 +7,7 @@ from uuid import uuid4
 import cloudpickle
 import zmq
 import zmq.asyncio
+from zmq import Frame  # type: ignore[attr-defined]
 from zmq.asyncio import Socket
 
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@@ -214,6 +215,7 @@ class AsyncEngineRPCClient:
 
             # Await the data from the Server.
             frame = await socket.recv(copy=False)
+            assert isinstance(frame, Frame)
             data = pickle.loads(frame.buffer)
 
             if isinstance(data, Exception):
@@ -247,6 +249,7 @@ class AsyncEngineRPCClient:
                                    f"{self._data_timeout} ms")
 
             frame = await socket.recv(copy=False)
+            assert isinstance(frame, Frame)
             return pickle.loads(frame.buffer)
 
         # Make a new socket connection.
@@ -395,6 +398,7 @@ class AsyncEngineRPCClient:
                 # Stream back the results from the RPC Server.
                 while not finished:
                     message = await socket.recv(copy=False)
+                    assert isinstance(message, Frame)
                     request_output = pickle.loads(message.buffer)
 
                     if isinstance(request_output, Exception):
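The new asserts exist because `socket.recv(copy=False)` returns a `zmq.Frame` rather than `bytes`, so the payload has to be read from `frame.buffer`; the `isinstance` check narrows the type for mypy and fails fast if a multipart list ever comes back instead. A self-contained sketch of that zero-copy receive path (the `inproc` endpoint and PAIR sockets are illustrative, not part of the commit):

```python
# Sketch of the zero-copy receive path the asserts guard. With copy=False,
# pyzmq hands back a zmq.Frame whose payload lives in frame.buffer.
import asyncio
import pickle

import zmq
import zmq.asyncio
from zmq import Frame  # type: ignore[attr-defined]


async def demo() -> None:
    ctx = zmq.asyncio.Context()
    server = ctx.socket(zmq.PAIR)
    client = ctx.socket(zmq.PAIR)
    server.bind("inproc://demo")
    client.connect("inproc://demo")

    await server.send(pickle.dumps({"text": "hello"}))

    frame = await client.recv(copy=False)
    assert isinstance(frame, Frame)    # zero-copy recv yields a Frame, not bytes
    data = pickle.loads(frame.buffer)  # deserialize straight from the buffer
    print(data)

    client.close()
    server.close()
    ctx.term()


asyncio.run(demo())
```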