From d130b573a0162173002b97e2112c6c1c10d0ca8e Mon Sep 17 00:00:00 2001
From: HUANG Fei
Date: Tue, 21 May 2024 13:22:22 +0800
Subject: [PATCH] [Model] add rope_scaling support for qwen2 (#4930)

---
 vllm/model_executor/models/qwen2.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 31ba6441..97ab6168 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -89,7 +89,8 @@ class Qwen2Attention(nn.Module):
                  use_sliding_window: bool = False,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 sliding_window: Optional[int] = None) -> None:
+                 sliding_window: Optional[int] = None,
+                 rope_scaling: Optional[Tuple] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -133,6 +134,7 @@ class Qwen2Attention(nn.Module):
             rotary_dim=self.head_dim,
             max_position=max_position,
             base=self.rope_theta,
+            rope_scaling=rope_scaling,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -169,6 +171,7 @@ class Qwen2DecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         use_sliding_window = (config.use_sliding_window and
                               layer_idx < config.max_window_layers)
         self.self_attn = Qwen2Attention(
@@ -180,7 +183,8 @@ class Qwen2DecoderLayer(nn.Module):
             use_sliding_window=use_sliding_window,
             cache_config=cache_config,
             quant_config=quant_config,
-            sliding_window=config.sliding_window)
+            sliding_window=config.sliding_window,
+            rope_scaling=rope_scaling)
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
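
For context, a minimal sketch (not part of the patch) of how the new argument flows: the decoder layer reads `rope_scaling` from the Hugging Face model config and forwards it through `Qwen2Attention` into `get_rope`, which builds the scaled rotary embedding. The config values below are hypothetical examples of the common Hugging Face `rope_scaling` schema; the exact keys depend on the model checkpoint.

```python
# Hypothetical illustration of the data flow added by this patch
# (not part of the diff above; config values are made up for the example).
from types import SimpleNamespace

# A model's config.json may carry a rope_scaling entry such as:
config = SimpleNamespace(
    rope_theta=1000000,
    rope_scaling={"type": "linear", "factor": 2.0},  # example values only
)

# Mirrors the new line in Qwen2DecoderLayer.__init__: older configs without
# a rope_scaling field fall back to None, keeping the previous behavior.
rope_scaling = getattr(config, "rope_scaling", None)

# Qwen2Attention then passes it straight through to the rotary embedding:
#   self.rotary_emb = get_rope(..., base=self.rope_theta,
#                              rope_scaling=rope_scaling)
print(rope_scaling)  # {'type': 'linear', 'factor': 2.0}
```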