[ Bugfix ] Fix Prometheus Metrics With zeromq Frontend (#7279)
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
parent
ab7165f2c7
commit
e3b318216d
@ -50,12 +50,3 @@ async def test_check_health(client: openai.AsyncOpenAI):
|
|||||||
response = requests.get(base_url + "/health")
|
response = requests.get(base_url + "/health")
|
||||||
|
|
||||||
assert response.status_code == HTTPStatus.OK
|
assert response.status_code == HTTPStatus.OK
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_log_metrics(client: openai.AsyncOpenAI):
|
|
||||||
base_url = str(client.base_url)[:-3].strip("/")
|
|
||||||
|
|
||||||
response = requests.get(base_url + "/metrics")
|
|
||||||
|
|
||||||
assert response.status_code == HTTPStatus.OK
|
|
||||||
|
|||||||
179
tests/entrypoints/openai/test_metrics.py
Normal file
179
tests/entrypoints/openai/test_metrics.py
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from prometheus_client.parser import text_string_to_metric_families
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def default_server_args():
|
||||||
|
return [
|
||||||
|
# use half precision for speed and memory savings in CI environment
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"1024",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"128",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module",
|
||||||
|
params=[
|
||||||
|
"",
|
||||||
|
"--enable-chunked-prefill",
|
||||||
|
"--disable-frontend-multiprocessing",
|
||||||
|
])
|
||||||
|
def client(default_server_args, request):
|
||||||
|
if request.param:
|
||||||
|
default_server_args.append(request.param)
|
||||||
|
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
|
||||||
|
yield remote_server.get_async_client()
|
||||||
|
|
||||||
|
|
||||||
|
_PROMPT = "Hello my name is Robert and I love magic"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||||
|
_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"]
|
||||||
|
|
||||||
|
_NUM_REQUESTS = 10
|
||||||
|
_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT)
|
||||||
|
_NUM_GENERATION_TOKENS_PER_REQUEST = 10
|
||||||
|
|
||||||
|
# {metric_family: [(suffix, expected_value)]}
|
||||||
|
EXPECTED_VALUES = {
|
||||||
|
"vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)],
|
||||||
|
"vllm:time_per_output_token_seconds":
|
||||||
|
[("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))],
|
||||||
|
"vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
|
||||||
|
"vllm:request_prompt_tokens":
|
||||||
|
[("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
|
||||||
|
("_count", _NUM_REQUESTS)],
|
||||||
|
"vllm:request_generation_tokens":
|
||||||
|
[("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
|
||||||
|
("_count", _NUM_REQUESTS)],
|
||||||
|
"vllm:request_params_n": [("_count", _NUM_REQUESTS)],
|
||||||
|
"vllm:request_params_best_of": [("_count", _NUM_REQUESTS)],
|
||||||
|
"vllm:prompt_tokens": [("_total",
|
||||||
|
_NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
|
||||||
|
"vllm:generation_tokens":
|
||||||
|
[("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
|
||||||
|
"vllm:request_success": [("_total", _NUM_REQUESTS)],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_metrics_counts(client: openai.AsyncOpenAI):
|
||||||
|
base_url = str(client.base_url)[:-3].strip("/")
|
||||||
|
|
||||||
|
for _ in range(_NUM_REQUESTS):
|
||||||
|
# sending a request triggers the metrics to be logged.
|
||||||
|
await client.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
prompt=_TOKENIZED_PROMPT,
|
||||||
|
max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST)
|
||||||
|
|
||||||
|
response = requests.get(base_url + "/metrics")
|
||||||
|
print(response.text)
|
||||||
|
assert response.status_code == HTTPStatus.OK
|
||||||
|
|
||||||
|
# Loop over all expected metric_families
|
||||||
|
for metric_family, suffix_values_list in EXPECTED_VALUES.items():
|
||||||
|
found_metric = False
|
||||||
|
|
||||||
|
# Check to see if the metric_family is found in the prom endpoint.
|
||||||
|
for family in text_string_to_metric_families(response.text):
|
||||||
|
if family.name == metric_family:
|
||||||
|
found_metric = True
|
||||||
|
|
||||||
|
# Check that each suffix is found in the prom endpoint.
|
||||||
|
for suffix, expected_value in suffix_values_list:
|
||||||
|
metric_name_w_suffix = f"{metric_family}{suffix}"
|
||||||
|
found_suffix = False
|
||||||
|
|
||||||
|
for sample in family.samples:
|
||||||
|
if sample.name == metric_name_w_suffix:
|
||||||
|
found_suffix = True
|
||||||
|
|
||||||
|
# For each suffix, value sure the value matches
|
||||||
|
# what we expect.
|
||||||
|
assert sample.value == expected_value, (
|
||||||
|
f"{metric_name_w_suffix} expected value of "
|
||||||
|
f"{expected_value} did not match found value "
|
||||||
|
f"{sample.value}")
|
||||||
|
break
|
||||||
|
assert found_suffix, (
|
||||||
|
f"Did not find {metric_name_w_suffix} in prom endpoint"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
assert found_metric, (f"Did not find {metric_family} in prom endpoint")
|
||||||
|
|
||||||
|
|
||||||
|
EXPECTED_METRICS = [
|
||||||
|
"vllm:num_requests_running",
|
||||||
|
"vllm:num_requests_swapped",
|
||||||
|
"vllm:num_requests_waiting",
|
||||||
|
"vllm:gpu_cache_usage_perc",
|
||||||
|
"vllm:cpu_cache_usage_perc",
|
||||||
|
"vllm:time_to_first_token_seconds_sum",
|
||||||
|
"vllm:time_to_first_token_seconds_bucket",
|
||||||
|
"vllm:time_to_first_token_seconds_count",
|
||||||
|
"vllm:time_per_output_token_seconds_sum",
|
||||||
|
"vllm:time_per_output_token_seconds_bucket",
|
||||||
|
"vllm:time_per_output_token_seconds_count",
|
||||||
|
"vllm:e2e_request_latency_seconds_sum",
|
||||||
|
"vllm:e2e_request_latency_seconds_bucket",
|
||||||
|
"vllm:e2e_request_latency_seconds_count",
|
||||||
|
"vllm:request_prompt_tokens_sum",
|
||||||
|
"vllm:request_prompt_tokens_bucket",
|
||||||
|
"vllm:request_prompt_tokens_count",
|
||||||
|
"vllm:request_generation_tokens_sum",
|
||||||
|
"vllm:request_generation_tokens_bucket",
|
||||||
|
"vllm:request_generation_tokens_count",
|
||||||
|
"vllm:request_params_n_sum",
|
||||||
|
"vllm:request_params_n_bucket",
|
||||||
|
"vllm:request_params_n_count",
|
||||||
|
"vllm:request_params_best_of_sum",
|
||||||
|
"vllm:request_params_best_of_bucket",
|
||||||
|
"vllm:request_params_best_of_count",
|
||||||
|
"vllm:num_preemptions_total",
|
||||||
|
"vllm:prompt_tokens_total",
|
||||||
|
"vllm:generation_tokens_total",
|
||||||
|
"vllm:request_success_total",
|
||||||
|
"vllm:cache_config_info",
|
||||||
|
# labels in cache_config_info
|
||||||
|
"block_size",
|
||||||
|
"cache_dtype",
|
||||||
|
"cpu_offload_gb",
|
||||||
|
"enable_prefix_caching",
|
||||||
|
"gpu_memory_utilization",
|
||||||
|
"num_cpu_blocks",
|
||||||
|
"num_gpu_blocks",
|
||||||
|
"num_gpu_blocks_override",
|
||||||
|
"sliding_window",
|
||||||
|
"swap_space_bytes",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_metrics_exist(client: openai.AsyncOpenAI):
|
||||||
|
base_url = str(client.base_url)[:-3].strip("/")
|
||||||
|
|
||||||
|
# sending a request triggers the metrics to be logged.
|
||||||
|
await client.completions.create(model=MODEL_NAME,
|
||||||
|
prompt="Hello, my name is",
|
||||||
|
max_tokens=5,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
response = requests.get(base_url + "/metrics")
|
||||||
|
assert response.status_code == HTTPStatus.OK
|
||||||
|
|
||||||
|
for metric in EXPECTED_METRICS:
|
||||||
|
assert metric in response.text
|
||||||
@ -15,7 +15,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
|||||||
from vllm.engine.async_timeout import asyncio_timeout
|
from vllm.engine.async_timeout import asyncio_timeout
|
||||||
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
|
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
|
||||||
PromptComponents)
|
PromptComponents)
|
||||||
from vllm.engine.metrics import StatLoggerBase
|
from vllm.engine.metrics_types import StatLoggerBase
|
||||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||||
from vllm.executor.ray_utils import initialize_ray_cluster, ray
|
from vllm.executor.ray_utils import initialize_ray_cluster, ray
|
||||||
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
|
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
|
||||||
|
|||||||
@ -16,8 +16,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
|||||||
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
|
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
|
||||||
SchedulerOutputs)
|
SchedulerOutputs)
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
|
from vllm.engine.metrics_types import StatLoggerBase, Stats
|
||||||
StatLoggerBase, Stats)
|
|
||||||
from vllm.engine.output_processor.interfaces import (
|
from vllm.engine.output_processor.interfaces import (
|
||||||
SequenceGroupOutputProcessor)
|
SequenceGroupOutputProcessor)
|
||||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||||
@ -339,6 +338,13 @@ class LLMEngine:
|
|||||||
if stat_loggers is not None:
|
if stat_loggers is not None:
|
||||||
self.stat_loggers = stat_loggers
|
self.stat_loggers = stat_loggers
|
||||||
else:
|
else:
|
||||||
|
# Lazy import for prometheus multiprocessing.
|
||||||
|
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||||
|
# before prometheus_client is imported.
|
||||||
|
# See https://prometheus.github.io/client_python/multiprocess/
|
||||||
|
from vllm.engine.metrics import (LoggingStatLogger,
|
||||||
|
PrometheusStatLogger)
|
||||||
|
|
||||||
self.stat_loggers = {
|
self.stat_loggers = {
|
||||||
"logging":
|
"logging":
|
||||||
LoggingStatLogger(
|
LoggingStatLogger(
|
||||||
|
|||||||
@ -1,13 +1,12 @@
|
|||||||
import time
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from typing import Counter as CollectionsCounter
|
from typing import Counter as CollectionsCounter
|
||||||
from typing import Dict, List, Optional, Protocol, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import prometheus_client
|
import prometheus_client
|
||||||
|
|
||||||
|
from vllm.engine.metrics_types import (StatLoggerBase, Stats,
|
||||||
|
SupportsMetricsInfo)
|
||||||
from vllm.executor.ray_utils import ray
|
from vllm.executor.ray_utils import ray
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
@ -29,41 +28,49 @@ prometheus_client.disable_created_metrics()
|
|||||||
|
|
||||||
# begin-metrics-definitions
|
# begin-metrics-definitions
|
||||||
class Metrics:
|
class Metrics:
|
||||||
|
"""
|
||||||
|
vLLM uses a multiprocessing-based frontend for the OpenAI server.
|
||||||
|
This means that we need to run prometheus_client in multiprocessing mode
|
||||||
|
See https://prometheus.github.io/client_python/multiprocess/ for more
|
||||||
|
details on limitations.
|
||||||
|
"""
|
||||||
labelname_finish_reason = "finished_reason"
|
labelname_finish_reason = "finished_reason"
|
||||||
_gauge_cls = prometheus_client.Gauge
|
_gauge_cls = prometheus_client.Gauge
|
||||||
_counter_cls = prometheus_client.Counter
|
_counter_cls = prometheus_client.Counter
|
||||||
_histogram_cls = prometheus_client.Histogram
|
_histogram_cls = prometheus_client.Histogram
|
||||||
|
|
||||||
def __init__(self, labelnames: List[str], max_model_len: int):
|
def __init__(self, labelnames: List[str], max_model_len: int):
|
||||||
# Unregister any existing vLLM collectors
|
# Unregister any existing vLLM collectors (for CI/CD)
|
||||||
self._unregister_vllm_metrics()
|
self._unregister_vllm_metrics()
|
||||||
|
|
||||||
# Config Information
|
|
||||||
self._create_info_cache_config()
|
|
||||||
|
|
||||||
# System stats
|
# System stats
|
||||||
# Scheduler State
|
# Scheduler State
|
||||||
self.gauge_scheduler_running = self._gauge_cls(
|
self.gauge_scheduler_running = self._gauge_cls(
|
||||||
name="vllm:num_requests_running",
|
name="vllm:num_requests_running",
|
||||||
documentation="Number of requests currently running on GPU.",
|
documentation="Number of requests currently running on GPU.",
|
||||||
labelnames=labelnames)
|
labelnames=labelnames,
|
||||||
|
multiprocess_mode="sum")
|
||||||
self.gauge_scheduler_waiting = self._gauge_cls(
|
self.gauge_scheduler_waiting = self._gauge_cls(
|
||||||
name="vllm:num_requests_waiting",
|
name="vllm:num_requests_waiting",
|
||||||
documentation="Number of requests waiting to be processed.",
|
documentation="Number of requests waiting to be processed.",
|
||||||
labelnames=labelnames)
|
labelnames=labelnames,
|
||||||
|
multiprocess_mode="sum")
|
||||||
self.gauge_scheduler_swapped = self._gauge_cls(
|
self.gauge_scheduler_swapped = self._gauge_cls(
|
||||||
name="vllm:num_requests_swapped",
|
name="vllm:num_requests_swapped",
|
||||||
documentation="Number of requests swapped to CPU.",
|
documentation="Number of requests swapped to CPU.",
|
||||||
labelnames=labelnames)
|
labelnames=labelnames,
|
||||||
|
multiprocess_mode="sum")
|
||||||
# KV Cache Usage in %
|
# KV Cache Usage in %
|
||||||
self.gauge_gpu_cache_usage = self._gauge_cls(
|
self.gauge_gpu_cache_usage = self._gauge_cls(
|
||||||
name="vllm:gpu_cache_usage_perc",
|
name="vllm:gpu_cache_usage_perc",
|
||||||
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
|
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
|
||||||
labelnames=labelnames)
|
labelnames=labelnames,
|
||||||
|
multiprocess_mode="sum")
|
||||||
self.gauge_cpu_cache_usage = self._gauge_cls(
|
self.gauge_cpu_cache_usage = self._gauge_cls(
|
||||||
name="vllm:cpu_cache_usage_perc",
|
name="vllm:cpu_cache_usage_perc",
|
||||||
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
|
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
|
||||||
labelnames=labelnames)
|
labelnames=labelnames,
|
||||||
|
multiprocess_mode="sum")
|
||||||
|
|
||||||
# Iteration stats
|
# Iteration stats
|
||||||
self.counter_num_preemption = self._counter_cls(
|
self.counter_num_preemption = self._counter_cls(
|
||||||
@ -137,11 +144,13 @@ class Metrics:
|
|||||||
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
|
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
|
||||||
name="vllm:spec_decode_draft_acceptance_rate",
|
name="vllm:spec_decode_draft_acceptance_rate",
|
||||||
documentation="Speulative token acceptance rate.",
|
documentation="Speulative token acceptance rate.",
|
||||||
labelnames=labelnames)
|
labelnames=labelnames,
|
||||||
|
multiprocess_mode="sum")
|
||||||
self.gauge_spec_decode_efficiency = self._gauge_cls(
|
self.gauge_spec_decode_efficiency = self._gauge_cls(
|
||||||
name="vllm:spec_decode_efficiency",
|
name="vllm:spec_decode_efficiency",
|
||||||
documentation="Speculative decoding system efficiency.",
|
documentation="Speculative decoding system efficiency.",
|
||||||
labelnames=labelnames)
|
labelnames=labelnames,
|
||||||
|
multiprocess_mode="sum")
|
||||||
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
|
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
|
||||||
name="vllm:spec_decode_num_accepted_tokens_total",
|
name="vllm:spec_decode_num_accepted_tokens_total",
|
||||||
documentation="Number of accepted tokens.",
|
documentation="Number of accepted tokens.",
|
||||||
@ -160,19 +169,18 @@ class Metrics:
|
|||||||
name="vllm:avg_prompt_throughput_toks_per_s",
|
name="vllm:avg_prompt_throughput_toks_per_s",
|
||||||
documentation="Average prefill throughput in tokens/s.",
|
documentation="Average prefill throughput in tokens/s.",
|
||||||
labelnames=labelnames,
|
labelnames=labelnames,
|
||||||
|
multiprocess_mode="sum",
|
||||||
)
|
)
|
||||||
# Deprecated in favor of vllm:generation_tokens_total
|
# Deprecated in favor of vllm:generation_tokens_total
|
||||||
self.gauge_avg_generation_throughput = self._gauge_cls(
|
self.gauge_avg_generation_throughput = self._gauge_cls(
|
||||||
name="vllm:avg_generation_throughput_toks_per_s",
|
name="vllm:avg_generation_throughput_toks_per_s",
|
||||||
documentation="Average generation throughput in tokens/s.",
|
documentation="Average generation throughput in tokens/s.",
|
||||||
labelnames=labelnames,
|
labelnames=labelnames,
|
||||||
|
multiprocess_mode="sum",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_info_cache_config(self) -> None:
|
|
||||||
# Config Information
|
# end-metrics-definitions
|
||||||
self.info_cache_config = prometheus_client.Info(
|
|
||||||
name='vllm:cache_config',
|
|
||||||
documentation='information of cache_config')
|
|
||||||
|
|
||||||
def _unregister_vllm_metrics(self) -> None:
|
def _unregister_vllm_metrics(self) -> None:
|
||||||
for collector in list(prometheus_client.REGISTRY._collector_to_names):
|
for collector in list(prometheus_client.REGISTRY._collector_to_names):
|
||||||
@ -180,9 +188,6 @@ class Metrics:
|
|||||||
prometheus_client.REGISTRY.unregister(collector)
|
prometheus_client.REGISTRY.unregister(collector)
|
||||||
|
|
||||||
|
|
||||||
# end-metrics-definitions
|
|
||||||
|
|
||||||
|
|
||||||
class _RayGaugeWrapper:
|
class _RayGaugeWrapper:
|
||||||
"""Wraps around ray.util.metrics.Gauge to provide same API as
|
"""Wraps around ray.util.metrics.Gauge to provide same API as
|
||||||
prometheus_client.Gauge"""
|
prometheus_client.Gauge"""
|
||||||
@ -190,7 +195,9 @@ class _RayGaugeWrapper:
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
name: str,
|
name: str,
|
||||||
documentation: str = "",
|
documentation: str = "",
|
||||||
labelnames: Optional[List[str]] = None):
|
labelnames: Optional[List[str]] = None,
|
||||||
|
multiprocess_mode: str = ""):
|
||||||
|
del multiprocess_mode
|
||||||
labelnames_tuple = tuple(labelnames) if labelnames else None
|
labelnames_tuple = tuple(labelnames) if labelnames else None
|
||||||
self._gauge = ray_metrics.Gauge(name=name,
|
self._gauge = ray_metrics.Gauge(name=name,
|
||||||
description=documentation,
|
description=documentation,
|
||||||
@ -268,10 +275,6 @@ class RayMetrics(Metrics):
|
|||||||
# No-op on purpose
|
# No-op on purpose
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _create_info_cache_config(self) -> None:
|
|
||||||
# No-op on purpose
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def build_1_2_5_buckets(max_value: int) -> List[int]:
|
def build_1_2_5_buckets(max_value: int) -> List[int]:
|
||||||
"""
|
"""
|
||||||
@ -295,46 +298,6 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
|
|||||||
exponent += 1
|
exponent += 1
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Stats:
|
|
||||||
"""Created by LLMEngine for use by StatLogger."""
|
|
||||||
now: float
|
|
||||||
|
|
||||||
# System stats (should have _sys suffix)
|
|
||||||
# Scheduler State
|
|
||||||
num_running_sys: int
|
|
||||||
num_waiting_sys: int
|
|
||||||
num_swapped_sys: int
|
|
||||||
# KV Cache Usage in %
|
|
||||||
gpu_cache_usage_sys: float
|
|
||||||
cpu_cache_usage_sys: float
|
|
||||||
|
|
||||||
# Iteration stats (should have _iter suffix)
|
|
||||||
num_prompt_tokens_iter: int
|
|
||||||
num_generation_tokens_iter: int
|
|
||||||
time_to_first_tokens_iter: List[float]
|
|
||||||
time_per_output_tokens_iter: List[float]
|
|
||||||
num_preemption_iter: int
|
|
||||||
|
|
||||||
# Request stats (should have _requests suffix)
|
|
||||||
# Latency
|
|
||||||
time_e2e_requests: List[float]
|
|
||||||
# Metadata
|
|
||||||
num_prompt_tokens_requests: List[int]
|
|
||||||
num_generation_tokens_requests: List[int]
|
|
||||||
best_of_requests: List[int]
|
|
||||||
n_requests: List[int]
|
|
||||||
finished_reason_requests: List[str]
|
|
||||||
|
|
||||||
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
|
||||||
|
|
||||||
|
|
||||||
class SupportsMetricsInfo(Protocol):
|
|
||||||
|
|
||||||
def metrics_info(self) -> Dict[str, str]:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def local_interval_elapsed(now: float, last_log: float,
|
def local_interval_elapsed(now: float, last_log: float,
|
||||||
local_interval: float) -> bool:
|
local_interval: float) -> bool:
|
||||||
elapsed_time = now - last_log
|
elapsed_time = now - last_log
|
||||||
@ -346,38 +309,9 @@ def get_throughput(tracked_stats: List[int], now: float,
|
|||||||
return float(np.sum(tracked_stats) / (now - last_log))
|
return float(np.sum(tracked_stats) / (now - last_log))
|
||||||
|
|
||||||
|
|
||||||
class StatLoggerBase(ABC):
|
|
||||||
"""Base class for StatLogger."""
|
|
||||||
|
|
||||||
def __init__(self, local_interval: float) -> None:
|
|
||||||
# Tracked stats over current local logging interval.
|
|
||||||
self.num_prompt_tokens: List[int] = []
|
|
||||||
self.num_generation_tokens: List[int] = []
|
|
||||||
self.last_local_log = time.time()
|
|
||||||
self.local_interval = local_interval
|
|
||||||
self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def log(self, stats: Stats) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def maybe_update_spec_decode_metrics(self, stats: Stats):
|
|
||||||
"""Save spec decode metrics (since they are unlikely
|
|
||||||
to be emitted at same time as log interval)."""
|
|
||||||
if stats.spec_decode_metrics is not None:
|
|
||||||
self.spec_decode_metrics = stats.spec_decode_metrics
|
|
||||||
|
|
||||||
|
|
||||||
class LoggingStatLogger(StatLoggerBase):
|
class LoggingStatLogger(StatLoggerBase):
|
||||||
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
|
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
|
||||||
|
|
||||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def log(self, stats: Stats) -> None:
|
def log(self, stats: Stats) -> None:
|
||||||
"""Called by LLMEngine.
|
"""Called by LLMEngine.
|
||||||
Logs to Stdout every self.local_interval seconds."""
|
Logs to Stdout every self.local_interval seconds."""
|
||||||
@ -440,10 +374,14 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
f"Number of draft tokens: {metrics.draft_tokens}, "
|
f"Number of draft tokens: {metrics.draft_tokens}, "
|
||||||
f"Number of emitted tokens: {metrics.emitted_tokens}.")
|
f"Number of emitted tokens: {metrics.emitted_tokens}.")
|
||||||
|
|
||||||
|
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class PrometheusStatLogger(StatLoggerBase):
|
class PrometheusStatLogger(StatLoggerBase):
|
||||||
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
|
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
|
||||||
_metrics_cls = Metrics
|
_metrics_cls = Metrics
|
||||||
|
_gauge_cls = prometheus_client.Gauge
|
||||||
|
|
||||||
def __init__(self, local_interval: float, labels: Dict[str, str],
|
def __init__(self, local_interval: float, labels: Dict[str, str],
|
||||||
max_model_len: int) -> None:
|
max_model_len: int) -> None:
|
||||||
@ -453,10 +391,6 @@ class PrometheusStatLogger(StatLoggerBase):
|
|||||||
self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
|
self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
|
||||||
max_model_len=max_model_len)
|
max_model_len=max_model_len)
|
||||||
|
|
||||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
|
||||||
if type == "cache_config":
|
|
||||||
self.metrics.info_cache_config.info(obj.metrics_info())
|
|
||||||
|
|
||||||
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
|
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
|
||||||
# Convenience function for logging to gauge.
|
# Convenience function for logging to gauge.
|
||||||
gauge.labels(**self.labels).set(data)
|
gauge.labels(**self.labels).set(data)
|
||||||
@ -586,6 +520,19 @@ class PrometheusStatLogger(StatLoggerBase):
|
|||||||
self.last_local_log = stats.now
|
self.last_local_log = stats.now
|
||||||
self.spec_decode_metrics = None
|
self.spec_decode_metrics = None
|
||||||
|
|
||||||
|
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||||
|
# Info type metrics are syntactic sugar for a gauge permanently set to 1
|
||||||
|
# Since prometheus multiprocessing mode does not support Info, emulate
|
||||||
|
# info here with a gauge.
|
||||||
|
if type == "cache_config":
|
||||||
|
metrics_info = obj.metrics_info()
|
||||||
|
info_gauge = self._gauge_cls(
|
||||||
|
name="vllm:cache_config_info",
|
||||||
|
documentation="Information of the LLMEngine CacheConfig",
|
||||||
|
labelnames=metrics_info.keys(),
|
||||||
|
multiprocess_mode="mostrecent")
|
||||||
|
info_gauge.labels(**metrics_info).set(1)
|
||||||
|
|
||||||
|
|
||||||
class RayPrometheusStatLogger(PrometheusStatLogger):
|
class RayPrometheusStatLogger(PrometheusStatLogger):
|
||||||
"""RayPrometheusStatLogger uses Ray metrics instead."""
|
"""RayPrometheusStatLogger uses Ray metrics instead."""
|
||||||
|
|||||||
85
vllm/engine/metrics_types.py
Normal file
85
vllm/engine/metrics_types.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
"""
|
||||||
|
These types are defined in this file to avoid importing vllm.engine.metrics
|
||||||
|
and therefore importing prometheus_client.
|
||||||
|
|
||||||
|
This is required due to usage of Prometheus multiprocess mode to enable
|
||||||
|
metrics after splitting out the uvicorn process from the engine process.
|
||||||
|
|
||||||
|
Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR
|
||||||
|
before prometheus_client is imported. Typically, this is done by setting
|
||||||
|
the env variable before launch, but since we are a library, we need to
|
||||||
|
do this in Python code and lazily import prometheus_client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Protocol
|
||||||
|
|
||||||
|
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Stats:
|
||||||
|
"""Created by LLMEngine for use by StatLogger."""
|
||||||
|
now: float
|
||||||
|
|
||||||
|
# System stats (should have _sys suffix)
|
||||||
|
# Scheduler State
|
||||||
|
num_running_sys: int
|
||||||
|
num_waiting_sys: int
|
||||||
|
num_swapped_sys: int
|
||||||
|
# KV Cache Usage in %
|
||||||
|
gpu_cache_usage_sys: float
|
||||||
|
cpu_cache_usage_sys: float
|
||||||
|
|
||||||
|
# Iteration stats (should have _iter suffix)
|
||||||
|
num_prompt_tokens_iter: int
|
||||||
|
num_generation_tokens_iter: int
|
||||||
|
time_to_first_tokens_iter: List[float]
|
||||||
|
time_per_output_tokens_iter: List[float]
|
||||||
|
num_preemption_iter: int
|
||||||
|
|
||||||
|
# Request stats (should have _requests suffix)
|
||||||
|
# Latency
|
||||||
|
time_e2e_requests: List[float]
|
||||||
|
# Metadata
|
||||||
|
num_prompt_tokens_requests: List[int]
|
||||||
|
num_generation_tokens_requests: List[int]
|
||||||
|
best_of_requests: List[int]
|
||||||
|
n_requests: List[int]
|
||||||
|
finished_reason_requests: List[str]
|
||||||
|
|
||||||
|
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SupportsMetricsInfo(Protocol):
|
||||||
|
|
||||||
|
def metrics_info(self) -> Dict[str, str]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class StatLoggerBase(ABC):
|
||||||
|
"""Base class for StatLogger."""
|
||||||
|
|
||||||
|
def __init__(self, local_interval: float) -> None:
|
||||||
|
# Tracked stats over current local logging interval.
|
||||||
|
self.num_prompt_tokens: List[int] = []
|
||||||
|
self.num_generation_tokens: List[int] = []
|
||||||
|
self.last_local_log = time.time()
|
||||||
|
self.local_interval = local_interval
|
||||||
|
self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def log(self, stats: Stats) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def maybe_update_spec_decode_metrics(self, stats: Stats):
|
||||||
|
"""Save spec decode metrics (since they are unlikely
|
||||||
|
to be emitted at same time as log interval)."""
|
||||||
|
if stats.spec_decode_metrics is not None:
|
||||||
|
self.spec_decode_metrics = stats.spec_decode_metrics
|
||||||
@ -2,7 +2,9 @@ import asyncio
|
|||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
|
import tempfile
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
@ -12,7 +14,6 @@ from fastapi import APIRouter, FastAPI, Request
|
|||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
from prometheus_client import make_asgi_app
|
|
||||||
from starlette.routing import Mount
|
from starlette.routing import Mount
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
@ -54,6 +55,7 @@ openai_serving_chat: OpenAIServingChat
|
|||||||
openai_serving_completion: OpenAIServingCompletion
|
openai_serving_completion: OpenAIServingCompletion
|
||||||
openai_serving_embedding: OpenAIServingEmbedding
|
openai_serving_embedding: OpenAIServingEmbedding
|
||||||
openai_serving_tokenization: OpenAIServingTokenization
|
openai_serving_tokenization: OpenAIServingTokenization
|
||||||
|
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||||
|
|
||||||
logger = init_logger('vllm.entrypoints.openai.api_server')
|
logger = init_logger('vllm.entrypoints.openai.api_server')
|
||||||
|
|
||||||
@ -109,6 +111,21 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
|
|||||||
|
|
||||||
# Otherwise, use the multiprocessing AsyncLLMEngine.
|
# Otherwise, use the multiprocessing AsyncLLMEngine.
|
||||||
else:
|
else:
|
||||||
|
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
|
||||||
|
# Make TemporaryDirectory for prometheus multiprocessing
|
||||||
|
# Note: global TemporaryDirectory will be automatically
|
||||||
|
# cleaned up upon exit.
|
||||||
|
global prometheus_multiproc_dir
|
||||||
|
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
|
||||||
|
os.environ[
|
||||||
|
"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
|
||||||
|
"This directory must be wiped between vLLM runs or "
|
||||||
|
"you will find inaccurate metrics. Unset the variable "
|
||||||
|
"and vLLM will properly handle cleanup.")
|
||||||
|
|
||||||
# Select random path for IPC.
|
# Select random path for IPC.
|
||||||
rpc_path = get_open_zmq_ipc_path()
|
rpc_path = get_open_zmq_ipc_path()
|
||||||
logger.info("Multiprocessing frontend to use %s for RPC Path.",
|
logger.info("Multiprocessing frontend to use %s for RPC Path.",
|
||||||
@ -149,13 +166,38 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
|
|||||||
# Wait for server process to join
|
# Wait for server process to join
|
||||||
rpc_server_process.join()
|
rpc_server_process.join()
|
||||||
|
|
||||||
|
# Lazy import for prometheus multiprocessing.
|
||||||
|
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||||
|
# before prometheus_client is imported.
|
||||||
|
# See https://prometheus.github.io/client_python/multiprocess/
|
||||||
|
from prometheus_client import multiprocess
|
||||||
|
multiprocess.mark_process_dead(rpc_server_process.pid)
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def mount_metrics(app: FastAPI):
|
def mount_metrics(app: FastAPI):
|
||||||
# Add prometheus asgi middleware to route /metrics requests
|
# Lazy import for prometheus multiprocessing.
|
||||||
metrics_route = Mount("/metrics", make_asgi_app())
|
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||||
|
# before prometheus_client is imported.
|
||||||
|
# See https://prometheus.github.io/client_python/multiprocess/
|
||||||
|
from prometheus_client import (CollectorRegistry, make_asgi_app,
|
||||||
|
multiprocess)
|
||||||
|
|
||||||
|
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
|
||||||
|
if prometheus_multiproc_dir_path is not None:
|
||||||
|
logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
|
||||||
|
prometheus_multiproc_dir_path)
|
||||||
|
registry = CollectorRegistry()
|
||||||
|
multiprocess.MultiProcessCollector(registry)
|
||||||
|
|
||||||
|
# Add prometheus asgi middleware to route /metrics requests
|
||||||
|
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
|
||||||
|
else:
|
||||||
|
# Add prometheus asgi middleware to route /metrics requests
|
||||||
|
metrics_route = Mount("/metrics", make_asgi_app())
|
||||||
|
|
||||||
# Workaround for 307 Redirect for /metrics
|
# Workaround for 307 Redirect for /metrics
|
||||||
metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
|
metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
|
||||||
app.routes.append(metrics_route)
|
app.routes.append(metrics_route)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user