[ Bugfix ] Fix Prometheus Metrics With zeromq Frontend (#7279)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Robert Shaw 2024-08-18 16:19:48 -04:00 committed by GitHub
parent ab7165f2c7
commit e3b318216d
7 changed files with 366 additions and 116 deletions

View File

@ -50,12 +50,3 @@ async def test_check_health(client: openai.AsyncOpenAI):
    response = requests.get(base_url + "/health")
    assert response.status_code == HTTPStatus.OK

-@pytest.mark.asyncio
-async def test_log_metrics(client: openai.AsyncOpenAI):
-    base_url = str(client.base_url)[:-3].strip("/")
-    response = requests.get(base_url + "/metrics")
-    assert response.status_code == HTTPStatus.OK

View File

@ -0,0 +1,179 @@
from http import HTTPStatus
import openai
import pytest
import requests
from prometheus_client.parser import text_string_to_metric_families
from transformers import AutoTokenizer
from ...utils import RemoteOpenAIServer
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@pytest.fixture(scope="module")
def default_server_args():
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "1024",
        "--enforce-eager",
        "--max-num-seqs",
        "128",
    ]
@pytest.fixture(scope="module",
                params=[
                    "",
                    "--enable-chunked-prefill",
                    "--disable-frontend-multiprocessing",
                ])
def client(default_server_args, request):
    if request.param:
        default_server_args.append(request.param)

    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
        yield remote_server.get_async_client()
_PROMPT = "Hello my name is Robert and I love magic"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"]
_NUM_REQUESTS = 10
_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT)
_NUM_GENERATION_TOKENS_PER_REQUEST = 10
# {metric_family: [(suffix, expected_value)]}
EXPECTED_VALUES = {
"vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)],
"vllm:time_per_output_token_seconds":
[("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))],
"vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_prompt_tokens":
[("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
("_count", _NUM_REQUESTS)],
"vllm:request_generation_tokens":
[("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
("_count", _NUM_REQUESTS)],
"vllm:request_params_n": [("_count", _NUM_REQUESTS)],
"vllm:request_params_best_of": [("_count", _NUM_REQUESTS)],
"vllm:prompt_tokens": [("_total",
_NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
"vllm:generation_tokens":
[("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
"vllm:request_success": [("_total", _NUM_REQUESTS)],
}
@pytest.mark.asyncio
async def test_metrics_counts(client: openai.AsyncOpenAI):
    base_url = str(client.base_url)[:-3].strip("/")

    for _ in range(_NUM_REQUESTS):
        # sending a request triggers the metrics to be logged.
        await client.completions.create(
            model=MODEL_NAME,
            prompt=_TOKENIZED_PROMPT,
            max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST)

    response = requests.get(base_url + "/metrics")
    print(response.text)
    assert response.status_code == HTTPStatus.OK

    # Loop over all expected metric_families
    for metric_family, suffix_values_list in EXPECTED_VALUES.items():
        found_metric = False

        # Check to see if the metric_family is found in the prom endpoint.
        for family in text_string_to_metric_families(response.text):
            if family.name == metric_family:
                found_metric = True

                # Check that each suffix is found in the prom endpoint.
                for suffix, expected_value in suffix_values_list:
                    metric_name_w_suffix = f"{metric_family}{suffix}"
                    found_suffix = False

                    for sample in family.samples:
                        if sample.name == metric_name_w_suffix:
                            found_suffix = True

                            # For each suffix, make sure the value matches
                            # what we expect.
                            assert sample.value == expected_value, (
                                f"{metric_name_w_suffix} expected value of "
                                f"{expected_value} did not match found value "
                                f"{sample.value}")
                            break
                    assert found_suffix, (
                        f"Did not find {metric_name_w_suffix} in prom endpoint"
                    )
                break
        assert found_metric, (f"Did not find {metric_family} in prom endpoint")
EXPECTED_METRICS = [
"vllm:num_requests_running",
"vllm:num_requests_swapped",
"vllm:num_requests_waiting",
"vllm:gpu_cache_usage_perc",
"vllm:cpu_cache_usage_perc",
"vllm:time_to_first_token_seconds_sum",
"vllm:time_to_first_token_seconds_bucket",
"vllm:time_to_first_token_seconds_count",
"vllm:time_per_output_token_seconds_sum",
"vllm:time_per_output_token_seconds_bucket",
"vllm:time_per_output_token_seconds_count",
"vllm:e2e_request_latency_seconds_sum",
"vllm:e2e_request_latency_seconds_bucket",
"vllm:e2e_request_latency_seconds_count",
"vllm:request_prompt_tokens_sum",
"vllm:request_prompt_tokens_bucket",
"vllm:request_prompt_tokens_count",
"vllm:request_generation_tokens_sum",
"vllm:request_generation_tokens_bucket",
"vllm:request_generation_tokens_count",
"vllm:request_params_n_sum",
"vllm:request_params_n_bucket",
"vllm:request_params_n_count",
"vllm:request_params_best_of_sum",
"vllm:request_params_best_of_bucket",
"vllm:request_params_best_of_count",
"vllm:num_preemptions_total",
"vllm:prompt_tokens_total",
"vllm:generation_tokens_total",
"vllm:request_success_total",
"vllm:cache_config_info",
# labels in cache_config_info
"block_size",
"cache_dtype",
"cpu_offload_gb",
"enable_prefix_caching",
"gpu_memory_utilization",
"num_cpu_blocks",
"num_gpu_blocks",
"num_gpu_blocks_override",
"sliding_window",
"swap_space_bytes",
]
@pytest.mark.asyncio
async def test_metrics_exist(client: openai.AsyncOpenAI):
    base_url = str(client.base_url)[:-3].strip("/")

    # sending a request triggers the metrics to be logged.
    await client.completions.create(model=MODEL_NAME,
                                    prompt="Hello, my name is",
                                    max_tokens=5,
                                    temperature=0.0)

    response = requests.get(base_url + "/metrics")
    assert response.status_code == HTTPStatus.OK

    for metric in EXPECTED_METRICS:
        assert metric in response.text
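
A minimal sketch of the same parsing pattern outside the test harness, assuming a vLLM server is already listening on http://localhost:8000 (the URL is illustrative):

from http import HTTPStatus

import requests
from prometheus_client.parser import text_string_to_metric_families

response = requests.get("http://localhost:8000/metrics")
assert response.status_code == HTTPStatus.OK

# Walk every vLLM metric family and print each sample with its labels.
for family in text_string_to_metric_families(response.text):
    if family.name.startswith("vllm:"):
        for sample in family.samples:
            print(sample.name, sample.labels, sample.value)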

View File

@ -15,7 +15,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
                                    PromptComponents)
-from vllm.engine.metrics import StatLoggerBase
+from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,

View File

@ -16,8 +16,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
-                                 StatLoggerBase, Stats)
+from vllm.engine.metrics_types import StatLoggerBase, Stats
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
@ -339,6 +338,13 @@ class LLMEngine:
        if stat_loggers is not None:
            self.stat_loggers = stat_loggers
        else:
+            # Lazy import for prometheus multiprocessing.
+            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
+            # before prometheus_client is imported.
+            # See https://prometheus.github.io/client_python/multiprocess/
+            from vllm.engine.metrics import (LoggingStatLogger,
+                                             PrometheusStatLogger)
            self.stat_loggers = {
                "logging":
                LoggingStatLogger(
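
The lazy import added here is what lets PROMETHEUS_MULTIPROC_DIR (set by the API server) take effect before prometheus_client loads. A sketch of the same deferred-import pattern, using the constructor signatures shown later in this diff; the helper name is hypothetical:

def build_default_stat_loggers(model_name: str, max_model_len: int,
                               local_interval: float = 5.0):
    # Importing here means PROMETHEUS_MULTIPROC_DIR can be set by the
    # caller before prometheus_client is loaded for the first time.
    from vllm.engine.metrics import LoggingStatLogger, PrometheusStatLogger

    return {
        "logging": LoggingStatLogger(local_interval=local_interval),
        "prometheus": PrometheusStatLogger(local_interval=local_interval,
                                           labels={"model_name": model_name},
                                           max_model_len=max_model_len),
    }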

View File

@ -1,13 +1,12 @@
-import time
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
-from typing import Dict, List, Optional, Protocol, Union
+from typing import Dict, List, Optional, Union

import numpy as np
import prometheus_client

+from vllm.engine.metrics_types import (StatLoggerBase, Stats,
+                                       SupportsMetricsInfo)
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
@ -29,41 +28,49 @@ prometheus_client.disable_created_metrics()
# begin-metrics-definitions
class Metrics:
+    """
+    vLLM uses a multiprocessing-based frontend for the OpenAI server.
+    This means that we need to run prometheus_client in multiprocessing mode
+    See https://prometheus.github.io/client_python/multiprocess/ for more
+    details on limitations.
+    """
    labelname_finish_reason = "finished_reason"
    _gauge_cls = prometheus_client.Gauge
    _counter_cls = prometheus_client.Counter
    _histogram_cls = prometheus_client.Histogram

    def __init__(self, labelnames: List[str], max_model_len: int):
-        # Unregister any existing vLLM collectors
+        # Unregister any existing vLLM collectors (for CI/CD)
        self._unregister_vllm_metrics()

-        # Config Information
-        self._create_info_cache_config()

        # System stats
        # Scheduler State
        self.gauge_scheduler_running = self._gauge_cls(
            name="vllm:num_requests_running",
            documentation="Number of requests currently running on GPU.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+            multiprocess_mode="sum")
        self.gauge_scheduler_waiting = self._gauge_cls(
            name="vllm:num_requests_waiting",
            documentation="Number of requests waiting to be processed.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+            multiprocess_mode="sum")
        self.gauge_scheduler_swapped = self._gauge_cls(
            name="vllm:num_requests_swapped",
            documentation="Number of requests swapped to CPU.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+            multiprocess_mode="sum")
        # KV Cache Usage in %
        self.gauge_gpu_cache_usage = self._gauge_cls(
            name="vllm:gpu_cache_usage_perc",
            documentation="GPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+            multiprocess_mode="sum")
        self.gauge_cpu_cache_usage = self._gauge_cls(
            name="vllm:cpu_cache_usage_perc",
            documentation="CPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+            multiprocess_mode="sum")

        # Iteration stats
        self.counter_num_preemption = self._counter_cls(
@ -137,11 +144,13 @@ class Metrics:
        self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
            name="vllm:spec_decode_draft_acceptance_rate",
            documentation="Speculative token acceptance rate.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+            multiprocess_mode="sum")
        self.gauge_spec_decode_efficiency = self._gauge_cls(
            name="vllm:spec_decode_efficiency",
            documentation="Speculative decoding system efficiency.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+            multiprocess_mode="sum")
        self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens_total",
            documentation="Number of accepted tokens.",
@ -160,19 +169,18 @@ class Metrics:
name="vllm:avg_prompt_throughput_toks_per_s", name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.", documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames, labelnames=labelnames,
multiprocess_mode="sum",
) )
# Deprecated in favor of vllm:generation_tokens_total # Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = self._gauge_cls( self.gauge_avg_generation_throughput = self._gauge_cls(
name="vllm:avg_generation_throughput_toks_per_s", name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.", documentation="Average generation throughput in tokens/s.",
labelnames=labelnames, labelnames=labelnames,
multiprocess_mode="sum",
) )
def _create_info_cache_config(self) -> None:
# Config Information # end-metrics-definitions
self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config',
documentation='information of cache_config')
def _unregister_vllm_metrics(self) -> None: def _unregister_vllm_metrics(self) -> None:
for collector in list(prometheus_client.REGISTRY._collector_to_names): for collector in list(prometheus_client.REGISTRY._collector_to_names):
@ -180,9 +188,6 @@ class Metrics:
                prometheus_client.REGISTRY.unregister(collector)

-# end-metrics-definitions

class _RayGaugeWrapper:
    """Wraps around ray.util.metrics.Gauge to provide same API as
    prometheus_client.Gauge"""
@ -190,7 +195,9 @@ class _RayGaugeWrapper:
    def __init__(self,
                 name: str,
                 documentation: str = "",
-                 labelnames: Optional[List[str]] = None):
+                 labelnames: Optional[List[str]] = None,
+                 multiprocess_mode: str = ""):
+        del multiprocess_mode
        labelnames_tuple = tuple(labelnames) if labelnames else None
        self._gauge = ray_metrics.Gauge(name=name,
                                        description=documentation,
@ -268,10 +275,6 @@ class RayMetrics(Metrics):
        # No-op on purpose
        pass

-    def _create_info_cache_config(self) -> None:
-        # No-op on purpose
-        pass

def build_1_2_5_buckets(max_value: int) -> List[int]:
    """
@ -295,46 +298,6 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
        exponent += 1

-@dataclass
-class Stats:
-    """Created by LLMEngine for use by StatLogger."""
-    now: float
-
-    # System stats (should have _sys suffix)
-    # Scheduler State
-    num_running_sys: int
-    num_waiting_sys: int
-    num_swapped_sys: int
-    # KV Cache Usage in %
-    gpu_cache_usage_sys: float
-    cpu_cache_usage_sys: float
-
-    # Iteration stats (should have _iter suffix)
-    num_prompt_tokens_iter: int
-    num_generation_tokens_iter: int
-    time_to_first_tokens_iter: List[float]
-    time_per_output_tokens_iter: List[float]
-    num_preemption_iter: int
-
-    # Request stats (should have _requests suffix)
-    # Latency
-    time_e2e_requests: List[float]
-    # Metadata
-    num_prompt_tokens_requests: List[int]
-    num_generation_tokens_requests: List[int]
-    best_of_requests: List[int]
-    n_requests: List[int]
-    finished_reason_requests: List[str]
-
-    spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
-
-class SupportsMetricsInfo(Protocol):
-
-    def metrics_info(self) -> Dict[str, str]:
-        ...

def local_interval_elapsed(now: float, last_log: float,
                           local_interval: float) -> bool:
    elapsed_time = now - last_log
@ -346,38 +309,9 @@ def get_throughput(tracked_stats: List[int], now: float,
    return float(np.sum(tracked_stats) / (now - last_log))

-class StatLoggerBase(ABC):
-    """Base class for StatLogger."""
-
-    def __init__(self, local_interval: float) -> None:
-        # Tracked stats over current local logging interval.
-        self.num_prompt_tokens: List[int] = []
-        self.num_generation_tokens: List[int] = []
-        self.last_local_log = time.time()
-        self.local_interval = local_interval
-        self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
-
-    @abstractmethod
-    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def log(self, stats: Stats) -> None:
-        raise NotImplementedError
-
-    def maybe_update_spec_decode_metrics(self, stats: Stats):
-        """Save spec decode metrics (since they are unlikely
-        to be emitted at same time as log interval)."""
-        if stats.spec_decode_metrics is not None:
-            self.spec_decode_metrics = stats.spec_decode_metrics

class LoggingStatLogger(StatLoggerBase):
    """LoggingStatLogger is used in LLMEngine to log to Stdout."""

-    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
-        raise NotImplementedError

    def log(self, stats: Stats) -> None:
        """Called by LLMEngine.
        Logs to Stdout every self.local_interval seconds."""
@ -440,10 +374,14 @@ class LoggingStatLogger(StatLoggerBase):
f"Number of draft tokens: {metrics.draft_tokens}, " f"Number of draft tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens: {metrics.emitted_tokens}.") f"Number of emitted tokens: {metrics.emitted_tokens}.")
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError
class PrometheusStatLogger(StatLoggerBase): class PrometheusStatLogger(StatLoggerBase):
"""PrometheusStatLogger is used LLMEngine to log to Promethus.""" """PrometheusStatLogger is used LLMEngine to log to Promethus."""
_metrics_cls = Metrics _metrics_cls = Metrics
_gauge_cls = prometheus_client.Gauge
def __init__(self, local_interval: float, labels: Dict[str, str], def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None: max_model_len: int) -> None:
@ -453,10 +391,6 @@ class PrometheusStatLogger(StatLoggerBase):
        self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
                                         max_model_len=max_model_len)

-    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
-        if type == "cache_config":
-            self.metrics.info_cache_config.info(obj.metrics_info())

    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
        # Convenience function for logging to gauge.
        gauge.labels(**self.labels).set(data)
@ -586,6 +520,19 @@ class PrometheusStatLogger(StatLoggerBase):
            self.last_local_log = stats.now
            self.spec_decode_metrics = None

+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        # Info type metrics are syntactic sugar for a gauge permanently set to 1
+        # Since prometheus multiprocessing mode does not support Info, emulate
+        # info here with a gauge.
+        if type == "cache_config":
+            metrics_info = obj.metrics_info()
+            info_gauge = self._gauge_cls(
+                name="vllm:cache_config_info",
+                documentation="Information of the LLMEngine CacheConfig",
+                labelnames=metrics_info.keys(),
+                multiprocess_mode="mostrecent")
+            info_gauge.labels(**metrics_info).set(1)

class RayPrometheusStatLogger(PrometheusStatLogger):
    """RayPrometheusStatLogger uses Ray metrics instead."""

View File

@ -0,0 +1,85 @@
"""
These types are defined in this file to avoid importing vllm.engine.metrics
and therefore importing prometheus_client.
This is required due to usage of Prometheus multiprocess mode to enable
metrics after splitting out the uvicorn process from the engine process.
Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR
before prometheus_client is imported. Typically, this is done by setting
the env variable before launch, but since we are a library, we need to
do this in Python code and lazily import prometheus_client.
"""
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@dataclass
class Stats:
    """Created by LLMEngine for use by StatLogger."""
    now: float

    # System stats (should have _sys suffix)
    # Scheduler State
    num_running_sys: int
    num_waiting_sys: int
    num_swapped_sys: int
    # KV Cache Usage in %
    gpu_cache_usage_sys: float
    cpu_cache_usage_sys: float

    # Iteration stats (should have _iter suffix)
    num_prompt_tokens_iter: int
    num_generation_tokens_iter: int
    time_to_first_tokens_iter: List[float]
    time_per_output_tokens_iter: List[float]
    num_preemption_iter: int

    # Request stats (should have _requests suffix)
    # Latency
    time_e2e_requests: List[float]
    # Metadata
    num_prompt_tokens_requests: List[int]
    num_generation_tokens_requests: List[int]
    best_of_requests: List[int]
    n_requests: List[int]
    finished_reason_requests: List[str]

    spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None


class SupportsMetricsInfo(Protocol):

    def metrics_info(self) -> Dict[str, str]:
        ...


class StatLoggerBase(ABC):
    """Base class for StatLogger."""

    def __init__(self, local_interval: float) -> None:
        # Tracked stats over current local logging interval.
        self.num_prompt_tokens: List[int] = []
        self.num_generation_tokens: List[int] = []
        self.last_local_log = time.time()
        self.local_interval = local_interval
        self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    @abstractmethod
    def log(self, stats: Stats) -> None:
        raise NotImplementedError

    @abstractmethod
    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        raise NotImplementedError

    def maybe_update_spec_decode_metrics(self, stats: Stats):
        """Save spec decode metrics (since they are unlikely
        to be emitted at same time as log interval)."""
        if stats.spec_decode_metrics is not None:
            self.spec_decode_metrics = stats.spec_decode_metrics
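
Since these base types no longer import prometheus_client, a stat logger can be written without touching Prometheus at all. A minimal sketch, assuming it is handed to LLMEngine through the stat_loggers argument shown earlier in this diff:

from vllm.engine.metrics_types import (StatLoggerBase, Stats,
                                       SupportsMetricsInfo)


class PrintingStatLogger(StatLoggerBase):
    """Toy logger that prints scheduler state every time log() is called."""

    def log(self, stats: Stats) -> None:
        print(f"running={stats.num_running_sys} "
              f"waiting={stats.num_waiting_sys} "
              f"swapped={stats.num_swapped_sys}")

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        # Config info is optional for this toy logger.
        pass


# Hypothetical wiring: stat_loggers={"printing": PrintingStatLogger(local_interval=5)}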

View File

@ -2,7 +2,9 @@ import asyncio
import importlib
import inspect
import multiprocessing
+import os
import re
+import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from http import HTTPStatus
@ -12,7 +14,6 @@ from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
-from prometheus_client import make_asgi_app
from starlette.routing import Mount

import vllm.envs as envs
@ -54,6 +55,7 @@ openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
+prometheus_multiproc_dir: tempfile.TemporaryDirectory

logger = init_logger('vllm.entrypoints.openai.api_server')
@ -109,6 +111,21 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    # Otherwise, use the multiprocessing AsyncLLMEngine.
    else:
+        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
+            # Make TemporaryDirectory for prometheus multiprocessing
+            # Note: global TemporaryDirectory will be automatically
+            # cleaned up upon exit.
+            global prometheus_multiproc_dir
+            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+            os.environ[
+                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+        else:
+            logger.warning(
+                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
+                "This directory must be wiped between vLLM runs or "
+                "you will find inaccurate metrics. Unset the variable "
+                "and vLLM will properly handle cleanup.")

        # Select random path for IPC.
        rpc_path = get_open_zmq_ipc_path()
        logger.info("Multiprocessing frontend to use %s for RPC Path.",
@ -149,13 +166,38 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
        # Wait for server process to join
        rpc_server_process.join()

+        # Lazy import for prometheus multiprocessing.
+        # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
+        # before prometheus_client is imported.
+        # See https://prometheus.github.io/client_python/multiprocess/
+        from prometheus_client import multiprocess
+        multiprocess.mark_process_dead(rpc_server_process.pid)


router = APIRouter()


def mount_metrics(app: FastAPI):
+    # Lazy import for prometheus multiprocessing.
+    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
+    # before prometheus_client is imported.
+    # See https://prometheus.github.io/client_python/multiprocess/
+    from prometheus_client import (CollectorRegistry, make_asgi_app,
+                                   multiprocess)
+
+    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
+    if prometheus_multiproc_dir_path is not None:
+        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
+                    prometheus_multiproc_dir_path)
+        registry = CollectorRegistry()
+        multiprocess.MultiProcessCollector(registry)
+
+        # Add prometheus asgi middleware to route /metrics requests
+        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
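
Taken together, the frontend-side wiring in this file boils down to: choose a multiprocess directory before prometheus_client is imported, then serve /metrics from a MultiProcessCollector. A condensed sketch (the FastAPI app and temporary-directory handling are illustrative, not the exact code above):

import os
import tempfile

from fastapi import FastAPI
from starlette.routing import Mount

# Must happen before prometheus_client is imported anywhere in the process.
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", tempfile.mkdtemp())

from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

app = FastAPI()
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
app.routes.append(Mount("/metrics", make_asgi_app(registry=registry)))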