[ Bugfix ] Fix Prometheus Metrics With zeromq Frontend (#7279)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
Robert Shaw 2024-08-18 16:19:48 -04:00 committed by GitHub
parent ab7165f2c7
commit e3b318216d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 366 additions and 116 deletions

View File

@ -50,12 +50,3 @@ async def test_check_health(client: openai.AsyncOpenAI):
response = requests.get(base_url + "/health")
assert response.status_code == HTTPStatus.OK
@pytest.mark.asyncio
async def test_log_metrics(client: openai.AsyncOpenAI):
base_url = str(client.base_url)[:-3].strip("/")
response = requests.get(base_url + "/metrics")
assert response.status_code == HTTPStatus.OK

View File

@ -0,0 +1,179 @@
from http import HTTPStatus
import openai
import pytest
import requests
from prometheus_client.parser import text_string_to_metric_families
from transformers import AutoTokenizer
from ...utils import RemoteOpenAIServer
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@pytest.fixture(scope="module")
def default_server_args():
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"1024",
"--enforce-eager",
"--max-num-seqs",
"128",
]
@pytest.fixture(scope="module",
params=[
"",
"--enable-chunked-prefill",
"--disable-frontend-multiprocessing",
])
def client(default_server_args, request):
if request.param:
default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server.get_async_client()
_PROMPT = "Hello my name is Robert and I love magic"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"]
_NUM_REQUESTS = 10
_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT)
_NUM_GENERATION_TOKENS_PER_REQUEST = 10
# {metric_family: [(suffix, expected_value)]}
EXPECTED_VALUES = {
"vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)],
"vllm:time_per_output_token_seconds":
[("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))],
"vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_prompt_tokens":
[("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
("_count", _NUM_REQUESTS)],
"vllm:request_generation_tokens":
[("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
("_count", _NUM_REQUESTS)],
"vllm:request_params_n": [("_count", _NUM_REQUESTS)],
"vllm:request_params_best_of": [("_count", _NUM_REQUESTS)],
"vllm:prompt_tokens": [("_total",
_NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
"vllm:generation_tokens":
[("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
"vllm:request_success": [("_total", _NUM_REQUESTS)],
}
@pytest.mark.asyncio
async def test_metrics_counts(client: openai.AsyncOpenAI):
base_url = str(client.base_url)[:-3].strip("/")
for _ in range(_NUM_REQUESTS):
# sending a request triggers the metrics to be logged.
await client.completions.create(
model=MODEL_NAME,
prompt=_TOKENIZED_PROMPT,
max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST)
response = requests.get(base_url + "/metrics")
print(response.text)
assert response.status_code == HTTPStatus.OK
# Loop over all expected metric_families
for metric_family, suffix_values_list in EXPECTED_VALUES.items():
found_metric = False
# Check to see if the metric_family is found in the prom endpoint.
for family in text_string_to_metric_families(response.text):
if family.name == metric_family:
found_metric = True
# Check that each suffix is found in the prom endpoint.
for suffix, expected_value in suffix_values_list:
metric_name_w_suffix = f"{metric_family}{suffix}"
found_suffix = False
for sample in family.samples:
if sample.name == metric_name_w_suffix:
found_suffix = True
# For each suffix, value sure the value matches
# what we expect.
assert sample.value == expected_value, (
f"{metric_name_w_suffix} expected value of "
f"{expected_value} did not match found value "
f"{sample.value}")
break
assert found_suffix, (
f"Did not find {metric_name_w_suffix} in prom endpoint"
)
break
assert found_metric, (f"Did not find {metric_family} in prom endpoint")
EXPECTED_METRICS = [
"vllm:num_requests_running",
"vllm:num_requests_swapped",
"vllm:num_requests_waiting",
"vllm:gpu_cache_usage_perc",
"vllm:cpu_cache_usage_perc",
"vllm:time_to_first_token_seconds_sum",
"vllm:time_to_first_token_seconds_bucket",
"vllm:time_to_first_token_seconds_count",
"vllm:time_per_output_token_seconds_sum",
"vllm:time_per_output_token_seconds_bucket",
"vllm:time_per_output_token_seconds_count",
"vllm:e2e_request_latency_seconds_sum",
"vllm:e2e_request_latency_seconds_bucket",
"vllm:e2e_request_latency_seconds_count",
"vllm:request_prompt_tokens_sum",
"vllm:request_prompt_tokens_bucket",
"vllm:request_prompt_tokens_count",
"vllm:request_generation_tokens_sum",
"vllm:request_generation_tokens_bucket",
"vllm:request_generation_tokens_count",
"vllm:request_params_n_sum",
"vllm:request_params_n_bucket",
"vllm:request_params_n_count",
"vllm:request_params_best_of_sum",
"vllm:request_params_best_of_bucket",
"vllm:request_params_best_of_count",
"vllm:num_preemptions_total",
"vllm:prompt_tokens_total",
"vllm:generation_tokens_total",
"vllm:request_success_total",
"vllm:cache_config_info",
# labels in cache_config_info
"block_size",
"cache_dtype",
"cpu_offload_gb",
"enable_prefix_caching",
"gpu_memory_utilization",
"num_cpu_blocks",
"num_gpu_blocks",
"num_gpu_blocks_override",
"sliding_window",
"swap_space_bytes",
]
@pytest.mark.asyncio
async def test_metrics_exist(client: openai.AsyncOpenAI):
base_url = str(client.base_url)[:-3].strip("/")
# sending a request triggers the metrics to be logged.
await client.completions.create(model=MODEL_NAME,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
response = requests.get(base_url + "/metrics")
assert response.status_code == HTTPStatus.OK
for metric in EXPECTED_METRICS:
assert metric in response.text

View File

@ -15,7 +15,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
PromptComponents)
from vllm.engine.metrics import StatLoggerBase
from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,

View File

@ -16,8 +16,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
StatLoggerBase, Stats)
from vllm.engine.metrics_types import StatLoggerBase, Stats
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
@ -339,6 +338,13 @@ class LLMEngine:
if stat_loggers is not None:
self.stat_loggers = stat_loggers
else:
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from vllm.engine.metrics import (LoggingStatLogger,
PrometheusStatLogger)
self.stat_loggers = {
"logging":
LoggingStatLogger(

View File

@ -1,13 +1,12 @@
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Protocol, Union
from typing import Dict, List, Optional, Union
import numpy as np
import prometheus_client
from vllm.engine.metrics_types import (StatLoggerBase, Stats,
SupportsMetricsInfo)
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
@ -29,41 +28,49 @@ prometheus_client.disable_created_metrics()
# begin-metrics-definitions
class Metrics:
"""
vLLM uses a multiprocessing-based frontend for the OpenAI server.
This means that we need to run prometheus_client in multiprocessing mode
See https://prometheus.github.io/client_python/multiprocess/ for more
details on limitations.
"""
labelname_finish_reason = "finished_reason"
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram
def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors
# Unregister any existing vLLM collectors (for CI/CD)
self._unregister_vllm_metrics()
# Config Information
self._create_info_cache_config()
# System stats
# Scheduler State
self.gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_scheduler_swapped = self._gauge_cls(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
# KV Cache Usage in %
self.gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_cpu_cache_usage = self._gauge_cls(
name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
# Iteration stats
self.counter_num_preemption = self._counter_cls(
@ -137,11 +144,13 @@ class Metrics:
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
name="vllm:spec_decode_draft_acceptance_rate",
documentation="Speulative token acceptance rate.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_spec_decode_efficiency = self._gauge_cls(
name="vllm:spec_decode_efficiency",
documentation="Speculative decoding system efficiency.",
labelnames=labelnames)
labelnames=labelnames,
multiprocess_mode="sum")
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
@ -160,19 +169,18 @@ class Metrics:
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
multiprocess_mode="sum",
)
# Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = self._gauge_cls(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
multiprocess_mode="sum",
)
def _create_info_cache_config(self) -> None:
# Config Information
self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config',
documentation='information of cache_config')
# end-metrics-definitions
def _unregister_vllm_metrics(self) -> None:
for collector in list(prometheus_client.REGISTRY._collector_to_names):
@ -180,9 +188,6 @@ class Metrics:
prometheus_client.REGISTRY.unregister(collector)
# end-metrics-definitions
class _RayGaugeWrapper:
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""
@ -190,7 +195,9 @@ class _RayGaugeWrapper:
def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None):
labelnames: Optional[List[str]] = None,
multiprocess_mode: str = ""):
del multiprocess_mode
labelnames_tuple = tuple(labelnames) if labelnames else None
self._gauge = ray_metrics.Gauge(name=name,
description=documentation,
@ -268,10 +275,6 @@ class RayMetrics(Metrics):
# No-op on purpose
pass
def _create_info_cache_config(self) -> None:
# No-op on purpose
pass
def build_1_2_5_buckets(max_value: int) -> List[int]:
"""
@ -295,46 +298,6 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
exponent += 1
@dataclass
class Stats:
"""Created by LLMEngine for use by StatLogger."""
now: float
# System stats (should have _sys suffix)
# Scheduler State
num_running_sys: int
num_waiting_sys: int
num_swapped_sys: int
# KV Cache Usage in %
gpu_cache_usage_sys: float
cpu_cache_usage_sys: float
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int
num_generation_tokens_iter: int
time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float]
num_preemption_iter: int
# Request stats (should have _requests suffix)
# Latency
time_e2e_requests: List[float]
# Metadata
num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int]
best_of_requests: List[int]
n_requests: List[int]
finished_reason_requests: List[str]
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class SupportsMetricsInfo(Protocol):
def metrics_info(self) -> Dict[str, str]:
...
def local_interval_elapsed(now: float, last_log: float,
local_interval: float) -> bool:
elapsed_time = now - last_log
@ -346,38 +309,9 @@ def get_throughput(tracked_stats: List[int], now: float,
return float(np.sum(tracked_stats) / (now - last_log))
class StatLoggerBase(ABC):
"""Base class for StatLogger."""
def __init__(self, local_interval: float) -> None:
# Tracked stats over current local logging interval.
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []
self.last_local_log = time.time()
self.local_interval = local_interval
self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
@abstractmethod
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError
@abstractmethod
def log(self, stats: Stats) -> None:
raise NotImplementedError
def maybe_update_spec_decode_metrics(self, stats: Stats):
"""Save spec decode metrics (since they are unlikely
to be emitted at same time as log interval)."""
if stats.spec_decode_metrics is not None:
self.spec_decode_metrics = stats.spec_decode_metrics
class LoggingStatLogger(StatLoggerBase):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError
def log(self, stats: Stats) -> None:
"""Called by LLMEngine.
Logs to Stdout every self.local_interval seconds."""
@ -440,10 +374,14 @@ class LoggingStatLogger(StatLoggerBase):
f"Number of draft tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens: {metrics.emitted_tokens}.")
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError
class PrometheusStatLogger(StatLoggerBase):
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
_metrics_cls = Metrics
_gauge_cls = prometheus_client.Gauge
def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None:
@ -453,10 +391,6 @@ class PrometheusStatLogger(StatLoggerBase):
self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
max_model_len=max_model_len)
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info())
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
@ -586,6 +520,19 @@ class PrometheusStatLogger(StatLoggerBase):
self.last_local_log = stats.now
self.spec_decode_metrics = None
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
# Info type metrics are syntactic sugar for a gauge permanently set to 1
# Since prometheus multiprocessing mode does not support Info, emulate
# info here with a gauge.
if type == "cache_config":
metrics_info = obj.metrics_info()
info_gauge = self._gauge_cls(
name="vllm:cache_config_info",
documentation="Information of the LLMEngine CacheConfig",
labelnames=metrics_info.keys(),
multiprocess_mode="mostrecent")
info_gauge.labels(**metrics_info).set(1)
class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""

View File

@ -0,0 +1,85 @@
"""
These types are defined in this file to avoid importing vllm.engine.metrics
and therefore importing prometheus_client.
This is required due to usage of Prometheus multiprocess mode to enable
metrics after splitting out the uvicorn process from the engine process.
Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR
before prometheus_client is imported. Typically, this is done by setting
the env variable before launch, but since we are a library, we need to
do this in Python code and lazily import prometheus_client.
"""
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@dataclass
class Stats:
"""Created by LLMEngine for use by StatLogger."""
now: float
# System stats (should have _sys suffix)
# Scheduler State
num_running_sys: int
num_waiting_sys: int
num_swapped_sys: int
# KV Cache Usage in %
gpu_cache_usage_sys: float
cpu_cache_usage_sys: float
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int
num_generation_tokens_iter: int
time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float]
num_preemption_iter: int
# Request stats (should have _requests suffix)
# Latency
time_e2e_requests: List[float]
# Metadata
num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int]
best_of_requests: List[int]
n_requests: List[int]
finished_reason_requests: List[str]
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class SupportsMetricsInfo(Protocol):
def metrics_info(self) -> Dict[str, str]:
...
class StatLoggerBase(ABC):
"""Base class for StatLogger."""
def __init__(self, local_interval: float) -> None:
# Tracked stats over current local logging interval.
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []
self.last_local_log = time.time()
self.local_interval = local_interval
self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
@abstractmethod
def log(self, stats: Stats) -> None:
raise NotImplementedError
@abstractmethod
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError
def maybe_update_spec_decode_metrics(self, stats: Stats):
"""Save spec decode metrics (since they are unlikely
to be emitted at same time as log interval)."""
if stats.spec_decode_metrics is not None:
self.spec_decode_metrics = stats.spec_decode_metrics

View File

@ -2,7 +2,9 @@ import asyncio
import importlib
import inspect
import multiprocessing
import os
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from http import HTTPStatus
@ -12,7 +14,6 @@ from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from starlette.routing import Mount
import vllm.envs as envs
@ -54,6 +55,7 @@ openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
prometheus_multiproc_dir: tempfile.TemporaryDirectory
logger = init_logger('vllm.entrypoints.openai.api_server')
@ -109,6 +111,21 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
global prometheus_multiproc_dir
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ[
"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
else:
logger.warning(
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")
# Select random path for IPC.
rpc_path = get_open_zmq_ipc_path()
logger.info("Multiprocessing frontend to use %s for RPC Path.",
@ -149,13 +166,38 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Wait for server process to join
rpc_server_process.join()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(rpc_server_process.pid)
router = APIRouter()
def mount_metrics(app: FastAPI):
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess)
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
else:
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(metrics_route)