From 708989341ef6361a5981d890a0e2f1b794323458 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 2 Aug 2024 16:18:45 -0700
Subject: [PATCH] [misc] add a flag to enable compile (#7092)

---
 vllm/envs.py                | 4 ++++
 vllm/worker/model_runner.py | 6 ++++++
 2 files changed, 10 insertions(+)

diff --git a/vllm/envs.py b/vllm/envs.py
index 5b8a65bd..595058bc 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -174,6 +174,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),
 
+    # Internal flag to enable Dynamo graph capture
+    "VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
+    lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 77734428..f9c26e0c 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -23,6 +23,7 @@ except ImportError:
     BatchPrefillWithPagedKVCacheWrapper = None
     FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
 
+import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, MultiModalConfig, ParallelConfig,
@@ -786,6 +787,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 "provided. Defaulting to scaling factors of 1.0. "
                 "This may lead to less accurate results!")
 
+        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
+            self.model = torch.compile(self.model,
+                                       fullgraph=True,
+                                       backend="eager")
+
     def save_sharded_state(
         self,
         path: str,
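
Note: below is a minimal, self-contained sketch of the pattern this patch
introduces, namely gating torch.compile behind an environment flag. The
ToyModel class and tensor shapes are illustrative assumptions, not part of
vLLM; only the VLLM_TEST_DYNAMO_GRAPH_CAPTURE flag name and the
torch.compile(fullgraph=True, backend="eager") call come from the patch.

    import os

    import torch
    import torch.nn as nn


    class ToyModel(nn.Module):
        """Stand-in for the vLLM model wrapped by GPUModelRunnerBase."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(8, 8)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))


    model = ToyModel()

    # Same gating logic as the patch: the env var is parsed as an int,
    # so any non-zero value turns on Dynamo graph capture.
    if int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")):
        # fullgraph=True makes Dynamo raise on graph breaks rather than
        # silently falling back to eager; backend="eager" runs the
        # captured graph without backend compilation, so only graph
        # capture is exercised.
        model = torch.compile(model, fullgraph=True, backend="eager")

    print(model(torch.randn(2, 8)))

To exercise the flag, set it in the environment when launching the process
(the script name here is a placeholder):

    VLLM_TEST_DYNAMO_GRAPH_CAPTURE=1 python sketch.py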