[V1] EngineCore supports profiling (#10564)

Signed-off-by: Abatom <abzhonghua@gmail.com>
Author: Zhonghua Deng (committed by GitHub)
Date: 2024-11-23 09:16:15 +08:00
Parent: 28598f3939
Commit: d345f409b7
5 changed files with 68 additions and 9 deletions

vllm/v1/engine/__init__.py

@@ -68,6 +68,11 @@ class EngineCoreOutputs(msgspec.Struct,
     outputs: List[EngineCoreOutput]


+@dataclass
+class EngineCoreProfile:
+    is_start: bool
+
+
 class EngineCoreRequestType(enum.Enum):
     """
     Request types defined as hex byte strings, so it can be sent over sockets
@@ -75,3 +80,4 @@ class EngineCoreRequestType(enum.Enum):
     """
     ADD = b'\x00'
     ABORT = b'\x01'
+    PROFILE = b'\x02'

vllm/v1/engine/async_llm.py

@@ -346,10 +346,10 @@ class AsyncLLM(EngineClient):
         logger.debug("Called check_health.")

     async def start_profile(self) -> None:
-        raise ValueError("Not supported on V1 yet.")
+        await self.engine_core.profile(True)

     async def stop_profile(self) -> None:
-        raise ValueError("Not supported on V1 yet.")
+        await self.engine_core.profile(False)

     @property
     def is_running(self) -> bool:
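Note: with this change, AsyncLLM.start_profile()/stop_profile() forward to the engine core instead of raising. A hedged usage sketch follows; the engine construction and the generate() arguments are illustrative assumptions, not part of this diff, and VLLM_TORCH_PROFILER_DIR must be set in the worker's environment or Worker.profile() raises RuntimeError.

from vllm import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM

async def profile_one_request(engine: AsyncLLM) -> None:
    # Start the torch profiler in the worker process.
    await engine.start_profile()
    async for _ in engine.generate("Hello, my name is",
                                   SamplingParams(max_tokens=32),
                                   request_id="profile-req-0"):
        pass  # consume the streamed outputs while the profiler records
    # Stop the profiler; the trace is flushed to VLLM_TORCH_PROFILER_DIR.
    await engine.stop_profile()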

vllm/v1/engine/core.py

@@ -1,4 +1,5 @@
 import multiprocessing
+import pickle
 import queue
 import threading
 import time
@@ -16,7 +17,8 @@ from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
-                            EngineCoreRequest, EngineCoreRequestType)
+                            EngineCoreProfile, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.engine.mm_input_mapper import MMInputMapper
 from vllm.v1.executor.gpu_executor import GPUExecutor
 from vllm.v1.request import Request, RequestStatus
@@ -126,6 +128,9 @@ class EngineCore:
             scheduler_output, output)
         return engine_core_outputs

+    def profile(self, is_start=True):
+        self.model_executor.worker.profile(is_start)
+

 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -312,11 +317,14 @@ class EngineCoreProc(EngineCore):
             self._last_logging_time = now

     def _handle_client_request(
-            self, request: Union[EngineCoreRequest, List[str]]) -> None:
+            self, request: Union[EngineCoreRequest, EngineCoreProfile,
+                                 List[str]]) -> None:
         """Handle EngineCoreRequest or EngineCoreABORT from Client."""

         if isinstance(request, EngineCoreRequest):
             self.add_request(request)
+        elif isinstance(request, EngineCoreProfile):
+            self.model_executor.worker.profile(request.is_start)
         else:
             # TODO: make an EngineCoreAbort wrapper
             assert isinstance(request, list)
@@ -341,6 +349,8 @@ class EngineCoreProc(EngineCore):
                 request = decoder_add_req.decode(request_data)
             elif request_type == EngineCoreRequestType.ABORT.value:
                 request = decoder_abort_req.decode(request_data)
+            elif request_type == EngineCoreRequestType.PROFILE.value:
+                request = pickle.loads(request_data)
             else:
                 raise ValueError(f"Unknown RequestType: {request_type}")

vllm/v1/engine/core_client.py

@@ -9,7 +9,8 @@ import zmq.asyncio
 from vllm.logger import init_logger
 from vllm.utils import get_open_zmq_ipc_path
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
-                            EngineCoreRequest, EngineCoreRequestType)
+                            EngineCoreProfile, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.serial_utils import PickleEncoder
@@ -58,6 +59,9 @@ class EngineCoreClient:
     def add_request(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError

+    async def profile(self, is_start=True) -> None:
+        raise NotImplementedError
+
     def abort_requests(self, request_ids: List[str]) -> None:
         raise NotImplementedError
@@ -95,6 +99,9 @@ class InprocClient(EngineCoreClient):
     def abort_requests(self, request_ids: List[str]) -> None:
         self.engine_core.abort_requests(request_ids)

+    async def profile(self, is_start=True) -> None:
+        self.engine_core.profile(is_start)
+

 class MPClient(EngineCoreClient):
     """
@@ -177,8 +184,10 @@ class SyncMPClient(MPClient):
         engine_core_outputs = self.decoder.decode(frame.buffer).outputs
         return engine_core_outputs

-    def _send_input(self, request_type: EngineCoreRequestType,
-                    request: Union[EngineCoreRequest, List[str]]) -> None:
+    def _send_input(
+            self, request_type: EngineCoreRequestType,
+            request: Union[EngineCoreRequest, EngineCoreProfile,
+                           List[str]]) -> None:

         # (RequestType, SerializedRequest)
         msg = (request_type.value, self.encoder.encode(request))
@@ -190,6 +199,10 @@ class SyncMPClient(MPClient):
     def abort_requests(self, request_ids: List[str]) -> None:
         self._send_input(EngineCoreRequestType.ABORT, request_ids)

+    async def profile(self, is_start=True) -> None:
+        self._send_input(EngineCoreRequestType.PROFILE,
+                         EngineCoreProfile(is_start))
+

 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -206,7 +219,8 @@ class AsyncMPClient(MPClient):
     async def _send_input(
             self, request_type: EngineCoreRequestType,
-            request: Union[EngineCoreRequest, List[str]]) -> None:
+            request: Union[EngineCoreRequest, EngineCoreProfile,
+                           List[str]]) -> None:

         msg = (request_type.value, self.encoder.encode(request))
         await self.input_socket.send_multipart(msg, copy=False)
@@ -217,3 +231,7 @@ class AsyncMPClient(MPClient):
     async def abort_requests_async(self, request_ids: List[str]) -> None:
         if len(request_ids) > 0:
             await self._send_input(EngineCoreRequestType.ABORT, request_ids)
+
+    async def profile(self, is_start=True) -> None:
+        await self._send_input(EngineCoreRequestType.PROFILE,
+                               EngineCoreProfile(is_start))
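Note: the PROFILE message reuses the existing framing: a (request-type byte, pickled payload) pair sent over the ZMQ input socket, with the engine-core process dispatching on the byte before deserializing. A standalone toy round trip is sketched below; the classes are re-declared locally for illustration, while the real path goes through PickleEncoder and send_multipart.

import enum
import pickle
from dataclasses import dataclass

@dataclass
class EngineCoreProfile:
    is_start: bool

class EngineCoreRequestType(enum.Enum):
    ADD = b'\x00'
    ABORT = b'\x01'
    PROFILE = b'\x02'

# Client side (_send_input): tag the serialized request with its type byte.
request_type, request_data = (EngineCoreRequestType.PROFILE.value,
                              pickle.dumps(EngineCoreProfile(is_start=True)))

# Engine-core side (process_input_socket): dispatch on the byte, then decode.
if request_type == EngineCoreRequestType.PROFILE.value:
    request = pickle.loads(request_data)
    assert request.is_start  # would be forwarded to worker.profile(True)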

vllm/v1/worker/gpu_worker.py

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
 import torch
 import torch.distributed

+import vllm.envs as envs
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
@@ -56,6 +57,22 @@ class Worker:
             init_cached_hf_modules()

         self.model_runner = GPUModelRunner(vllm_config)

+        # Torch profiler. Enabled and configured through env vars:
+        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+        if envs.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        torch_profiler_trace_dir)
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    torch_profiler_trace_dir, use_gzip=True))
+        else:
+            self.profiler = None
+
     def initialize(self):
         if self.device_config.device.type == "cuda":
@@ -184,6 +201,14 @@ class Worker:
         # TODO(woosuk): Send the output to the engine process.
         return output

+    def profile(self, is_start=True):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        if is_start:
+            self.profiler.start()
+        else:
+            self.profiler.stop()
+

 def init_worker_distributed_environment(
     parallel_config: ParallelConfig,
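Note: the worker wraps the stock torch.profiler start/stop API, gated on VLLM_TORCH_PROFILER_DIR. A minimal standalone sketch of the same pattern follows, assuming a CUDA-capable PyTorch install; tensorboard_trace_handler writes gzipped *.pt.trace.json files into the directory, viewable with TensorBoard's profiler plugin or Perfetto. The directory path is a placeholder, not from this diff.

import torch

trace_dir = "/tmp/vllm_profile"  # stands in for VLLM_TORCH_PROFILER_DIR

profiler = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    with_stack=True,
    on_trace_ready=torch.profiler.tensorboard_trace_handler(
        trace_dir, use_gzip=True))

profiler.start()                      # what Worker.profile(True) does
x = torch.randn(1024, 1024, device="cuda")
(x @ x).sum().item()                  # some GPU work to capture
profiler.stop()                       # Worker.profile(False); trace is flushed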