From d7a4f2207bd0ff31cacf311a05266557d66e474e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 11 Nov 2024 11:05:57 -0800 Subject: [PATCH] [V1] Do not use inductor for piecewise CUDA graphs (#10225) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 24690485..1e20920d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -404,15 +404,14 @@ class GPUModelRunner: def load_model(self) -> None: if self.use_cuda_graph: - # FIXME(woosuk): Currently, the custom ops are not supported - # in the piecewise compilation mode. We rely on TorchInductor - # to optimize the model. + # FIXME(woosuk): Currently, we do not use inductor to reduce the + # compilation time and any potential issues with the inductor. os.environ["VLLM_CUSTOM_OPS"] = "none" set_compilation_config( CompilationConfig( use_cudagraph=True, non_cudagraph_ops=["vllm.unified_v1_flash_attention"], - use_inductor=True, + use_inductor=False, )) logger.info("Starting to load model %s...", self.model_config.model)