From 76ed5340f0ec0e481593ea1a94459b4b55136a4f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 28 Oct 2024 14:35:17 -0700 Subject: [PATCH] [torch.compile] add deepseek v2 compile (#9775) Signed-off-by: youkaichao --- vllm/model_executor/models/deepseek_v2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 38114836..d4ad0c6b 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -28,6 +28,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, @@ -403,6 +404,7 @@ class DeepseekV2DecoderLayer(nn.Module): return hidden_states, residual +@support_torch_compile class DeepseekV2Model(nn.Module): fall_back_to_pt_during_load = False