[CI]Add regression tests to ensure the async engine generates metrics (#4524)
This commit is contained in:
parent
0d62fe58db
commit
5e401bce17
@ -1,4 +1,10 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
from prometheus_client import REGISTRY
|
||||||
|
|
||||||
|
from vllm import EngineArgs, LLMEngine
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
"facebook/opt-125m",
|
"facebook/opt-125m",
|
||||||
@ -68,3 +74,91 @@ def test_metric_counter_generation_tokens(
|
|||||||
assert vllm_generation_count == metric_count, (
|
assert vllm_generation_count == metric_count, (
|
||||||
f"generation token count: {vllm_generation_count!r}\n"
|
f"generation token count: {vllm_generation_count!r}\n"
|
||||||
f"metric: {metric_count!r}")
|
f"metric: {metric_count!r}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [4])
|
||||||
|
@pytest.mark.parametrize("disable_log_stats", [True, False])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_engine_log_metrics_regression(
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
disable_log_stats: bool,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Regression test ensuring async engine generates metrics
|
||||||
|
when disable_log_stats=False
|
||||||
|
(see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678)
|
||||||
|
"""
|
||||||
|
engine_args = AsyncEngineArgs(model=model,
|
||||||
|
dtype=dtype,
|
||||||
|
disable_log_stats=disable_log_stats)
|
||||||
|
async_engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
|
for i, prompt in enumerate(example_prompts):
|
||||||
|
results = async_engine.generate(
|
||||||
|
prompt,
|
||||||
|
SamplingParams(max_tokens=max_tokens),
|
||||||
|
f"request-id-{i}",
|
||||||
|
)
|
||||||
|
# Exhaust the async iterator to make the async engine work
|
||||||
|
async for _ in results:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert_metrics(async_engine.engine, disable_log_stats,
|
||||||
|
len(example_prompts))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [4])
|
||||||
|
@pytest.mark.parametrize("disable_log_stats", [True, False])
|
||||||
|
def test_engine_log_metrics_regression(
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
disable_log_stats: bool,
|
||||||
|
) -> None:
|
||||||
|
engine_args = EngineArgs(model=model,
|
||||||
|
dtype=dtype,
|
||||||
|
disable_log_stats=disable_log_stats)
|
||||||
|
engine = LLMEngine.from_engine_args(engine_args)
|
||||||
|
for i, prompt in enumerate(example_prompts):
|
||||||
|
engine.add_request(
|
||||||
|
f"request-id-{i}",
|
||||||
|
prompt,
|
||||||
|
SamplingParams(max_tokens=max_tokens),
|
||||||
|
)
|
||||||
|
while engine.has_unfinished_requests():
|
||||||
|
engine.step()
|
||||||
|
|
||||||
|
assert_metrics(engine, disable_log_stats, len(example_prompts))
|
||||||
|
|
||||||
|
|
||||||
|
def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
|
||||||
|
num_requests: int) -> None:
|
||||||
|
if disable_log_stats:
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
_ = engine.stat_logger
|
||||||
|
else:
|
||||||
|
assert (engine.stat_logger
|
||||||
|
is not None), "engine.stat_logger should be set"
|
||||||
|
# Ensure the count bucket of request-level histogram metrics matches
|
||||||
|
# the number of requests as a simple sanity check to ensure metrics are
|
||||||
|
# generated
|
||||||
|
labels = {'model_name': engine.model_config.model}
|
||||||
|
request_histogram_metrics = [
|
||||||
|
"vllm:e2e_request_latency_seconds",
|
||||||
|
"vllm:request_prompt_tokens",
|
||||||
|
"vllm:request_generation_tokens",
|
||||||
|
"vllm:request_params_best_of",
|
||||||
|
"vllm:request_params_n",
|
||||||
|
]
|
||||||
|
for metric_name in request_histogram_metrics:
|
||||||
|
metric_value = REGISTRY.get_sample_value(f"{metric_name}_count",
|
||||||
|
labels)
|
||||||
|
assert (
|
||||||
|
metric_value == num_requests), "Metrics should be collected"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user