[V1] EngineCore supports profiling (#10564)
Signed-off-by: Abatom <abzhonghua@gmail.com>
parent 28598f3939
commit d345f409b7
vllm/v1/engine/__init__.py
@@ -68,6 +68,11 @@ class EngineCoreOutputs(msgspec.Struct,
     outputs: List[EngineCoreOutput]
 
 
+@dataclass
+class EngineCoreProfile:
+    is_start: bool
+
+
 class EngineCoreRequestType(enum.Enum):
     """
     Request types defined as hex byte strings, so it can be sent over sockets
@@ -75,3 +80,4 @@ class EngineCoreRequestType(enum.Enum):
     """
     ADD = b'\x00'
     ABORT = b'\x01'
+    PROFILE = b'\x02'
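Note: with this change a profile toggle travels as a two-frame message: the PROFILE type byte followed by a pickled EngineCoreProfile. A minimal standalone sketch of that framing (plain pickle here; the real clients go through vLLM's PickleEncoder):

import pickle
from dataclasses import dataclass
from enum import Enum


@dataclass
class EngineCoreProfile:
    is_start: bool


class EngineCoreRequestType(Enum):
    ADD = b'\x00'
    ABORT = b'\x01'
    PROFILE = b'\x02'


def frame(request_type: EngineCoreRequestType, payload) -> tuple:
    # Frame 0 identifies the request type; frame 1 carries the body.
    return (request_type.value, pickle.dumps(payload))


msg = frame(EngineCoreRequestType.PROFILE, EngineCoreProfile(is_start=True))
assert msg[0] == b'\x02'
assert pickle.loads(msg[1]).is_start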
vllm/v1/engine/async_llm.py
@@ -346,10 +346,10 @@ class AsyncLLM(EngineClient):
         logger.debug("Called check_health.")
 
     async def start_profile(self) -> None:
-        raise ValueError("Not supported on V1 yet.")
+        await self.engine_core.profile(True)
 
     async def stop_profile(self) -> None:
-        raise ValueError("Not supported on V1 yet.")
+        await self.engine_core.profile(False)
 
     @property
     def is_running(self) -> bool:
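Note: a hedged usage sketch of the now-working API. The constructor call and model name are illustrative, not taken from this diff; profiling also requires VLLM_TORCH_PROFILER_DIR to be set before the engine starts (see the worker changes below), otherwise the worker raises.

import asyncio
import os

# Must be set before the engine (and its worker) is created.
os.environ["VLLM_TORCH_PROFILER_DIR"] = "/tmp/vllm_profile"

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM


async def main():
    # Illustrative construction; any small model works for a smoke test.
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    await engine.start_profile()
    # ... submit generation requests here; they run under the profiler ...
    await engine.stop_profile()  # trace lands in VLLM_TORCH_PROFILER_DIR


asyncio.run(main())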
vllm/v1/engine/core.py
@@ -1,4 +1,5 @@
 import multiprocessing
+import pickle
 import queue
 import threading
 import time
@@ -16,7 +17,8 @@ from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
-                            EngineCoreRequest, EngineCoreRequestType)
+                            EngineCoreProfile, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.engine.mm_input_mapper import MMInputMapper
 from vllm.v1.executor.gpu_executor import GPUExecutor
 from vllm.v1.request import Request, RequestStatus
@@ -126,6 +128,9 @@ class EngineCore:
             scheduler_output, output)
         return engine_core_outputs
 
+    def profile(self, is_start=True):
+        self.model_executor.worker.profile(is_start)
+
 
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -312,11 +317,14 @@ class EngineCoreProc(EngineCore):
             self._last_logging_time = now
 
     def _handle_client_request(
-            self, request: Union[EngineCoreRequest, List[str]]) -> None:
+            self, request: Union[EngineCoreRequest, EngineCoreProfile,
+                                 List[str]]) -> None:
         """Handle EngineCoreRequest or EngineCoreABORT from Client."""
 
         if isinstance(request, EngineCoreRequest):
             self.add_request(request)
+        elif isinstance(request, EngineCoreProfile):
+            self.model_executor.worker.profile(request.is_start)
         else:
             # TODO: make an EngineCoreAbort wrapper
             assert isinstance(request, list)
@@ -341,6 +349,8 @@ class EngineCoreProc(EngineCore):
             request = decoder_add_req.decode(request_data)
         elif request_type == EngineCoreRequestType.ABORT.value:
             request = decoder_abort_req.decode(request_data)
+        elif request_type == EngineCoreRequestType.PROFILE.value:
+            request = pickle.loads(request_data)
         else:
             raise ValueError(f"Unknown RequestType: {request_type}")
 
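Note: the receive side mirrors the enum: ADD and ABORT go through msgpack decoders, while PROFILE falls back to pickle. A stripped-down sketch of that dispatch, standalone for clarity:

import pickle


def decode_request(request_type: bytes, request_data: bytes):
    # ADD/ABORT use msgspec decoders in the real code; PROFILE is pickled.
    if request_type == b'\x02':  # EngineCoreRequestType.PROFILE.value
        return pickle.loads(request_data)  # -> EngineCoreProfile
    raise ValueError(f"Unknown RequestType: {request_type!r}")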
vllm/v1/engine/core_client.py
@@ -9,7 +9,8 @@ import zmq.asyncio
 from vllm.logger import init_logger
 from vllm.utils import get_open_zmq_ipc_path
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
-                            EngineCoreRequest, EngineCoreRequestType)
+                            EngineCoreProfile, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.serial_utils import PickleEncoder
 
@@ -58,6 +59,9 @@ class EngineCoreClient:
     def add_request(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError
 
+    async def profile(self, is_start=True) -> None:
+        raise NotImplementedError
+
     def abort_requests(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
@@ -95,6 +99,9 @@ class InprocClient(EngineCoreClient):
     def abort_requests(self, request_ids: List[str]) -> None:
         self.engine_core.abort_requests(request_ids)
 
+    async def profile(self, is_start=True) -> None:
+        self.engine_core.profile(is_start)
+
 
 class MPClient(EngineCoreClient):
     """
@@ -177,8 +184,10 @@ class SyncMPClient(MPClient):
         engine_core_outputs = self.decoder.decode(frame.buffer).outputs
         return engine_core_outputs
 
-    def _send_input(self, request_type: EngineCoreRequestType,
-                    request: Union[EngineCoreRequest, List[str]]) -> None:
+    def _send_input(
+            self, request_type: EngineCoreRequestType,
+            request: Union[EngineCoreRequest, EngineCoreProfile,
+                           List[str]]) -> None:
 
         # (RequestType, SerializedRequest)
         msg = (request_type.value, self.encoder.encode(request))
@@ -190,6 +199,10 @@ class SyncMPClient(MPClient):
     def abort_requests(self, request_ids: List[str]) -> None:
         self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
+    async def profile(self, is_start=True) -> None:
+        self._send_input(EngineCoreRequestType.PROFILE,
+                         EngineCoreProfile(is_start))
+
 
 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -206,7 +219,8 @@ class AsyncMPClient(MPClient):
 
     async def _send_input(
             self, request_type: EngineCoreRequestType,
-            request: Union[EngineCoreRequest, List[str]]) -> None:
+            request: Union[EngineCoreRequest, EngineCoreProfile,
+                           List[str]]) -> None:
 
         msg = (request_type.value, self.encoder.encode(request))
         await self.input_socket.send_multipart(msg, copy=False)
@@ -217,3 +231,7 @@
     async def abort_requests_async(self, request_ids: List[str]) -> None:
         if len(request_ids) > 0:
             await self._send_input(EngineCoreRequestType.ABORT, request_ids)
+
+    async def profile(self, is_start=True) -> None:
+        await self._send_input(EngineCoreRequestType.PROFILE,
+                               EngineCoreProfile(is_start))
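Note: both MP clients ride on the same transport pattern: a multipart (type, payload) message over ZMQ. A self-contained sketch of that round trip, under assumed PUSH/PULL socket types and a demo IPC address:

import asyncio
import pickle

import zmq
import zmq.asyncio


async def demo():
    ctx = zmq.asyncio.Context()
    addr = "ipc:///tmp/engine_core_demo"
    pull = ctx.socket(zmq.PULL)  # stands in for EngineCoreProc's input socket
    pull.bind(addr)
    push = ctx.socket(zmq.PUSH)  # stands in for the client's input socket
    push.connect(addr)

    # (RequestType, SerializedRequest), as in _send_input above.
    await push.send_multipart([b'\x02', pickle.dumps({"is_start": True})])
    request_type, request_data = await pull.recv_multipart()
    assert request_type == b'\x02'
    assert pickle.loads(request_data)["is_start"]

    push.close()
    pull.close()
    ctx.term()


asyncio.run(demo())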
vllm/v1/worker/gpu_worker.py
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
 import torch
 import torch.distributed
 
+import vllm.envs as envs
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
@@ -56,6 +57,22 @@ class Worker:
             init_cached_hf_modules()
 
         self.model_runner = GPUModelRunner(vllm_config)
+        # Torch profiler. Enabled and configured through env vars:
+        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+        if envs.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        torch_profiler_trace_dir)
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    torch_profiler_trace_dir, use_gzip=True))
+        else:
+            self.profiler = None
 
     def initialize(self):
         if self.device_config.device.type == "cuda":
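Note: the profiler configuration above, reproduced as a runnable standalone snippet against a toy workload; the trace directory is an example path standing in for $VLLM_TORCH_PROFILER_DIR, and the snippet is meant for a CUDA machine.

import torch

trace_dir = "/tmp/vllm_profile"  # example path

prof = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    with_stack=True,
    on_trace_ready=torch.profiler.tensorboard_trace_handler(
        trace_dir, use_gzip=True))

prof.start()
x = torch.randn(512, 512)
for _ in range(10):
    x = x @ x  # toy workload standing in for model execution
prof.stop()  # the trace handler writes a gzipped trace on stop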
@@ -184,6 +201,14 @@ class Worker:
         # TODO(woosuk): Send the output to the engine process.
         return output
 
+    def profile(self, is_start=True):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        if is_start:
+            self.profiler.start()
+        else:
+            self.profiler.stop()
+
 
 def init_worker_distributed_environment(
     parallel_config: ParallelConfig,
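Note: a toy harness pinning down the toggle semantics of Worker.profile: start/stop are paired calls, and toggling without VLLM_TORCH_PROFILER_DIR set raises. The Fake* classes are stand-ins for illustration, not vLLM APIs.

class FakeProfiler:
    def __init__(self):
        self.running = False

    def start(self):
        self.running = True

    def stop(self):
        self.running = False


class FakeWorker:
    """Mimics Worker.profile(); profiler is None when the env var is unset."""

    def __init__(self, profiler=None):
        self.profiler = profiler

    def profile(self, is_start=True):
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()


w = FakeWorker(FakeProfiler())
w.profile(True)
assert w.profiler.running
w.profile(False)
assert not w.profiler.running

try:
    FakeWorker().profile(True)
except RuntimeError as err:
    assert "not enabled" in str(err)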