diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 0bde4cef..9c109cc3 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -272,6 +272,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         return cache
 
 
+_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
+
+
 def get_rope(
     head_size: int,
     rotary_dim: int,
@@ -280,6 +283,10 @@ def get_rope(
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
+    key = (head_size, rotary_dim, max_position, base, is_neox_style,
+           rope_scaling)
+    if key in _ROPE_DICT:
+        return _ROPE_DICT[key]
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position,
                                      base, is_neox_style)
@@ -312,4 +319,5 @@ def get_rope(
                                      **extra_kwargs)
     else:
         raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+    _ROPE_DICT[key] = rotary_emb
     return rotary_emb
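
The patch memoizes `get_rope` on its full argument list via the module-level `_ROPE_DICT`, so layers that share a RoPE configuration reuse one `RotaryEmbedding` module instead of rebuilding the cos/sin cache per layer. Below is a minimal standalone sketch of that caching pattern; the stub `RotaryEmbedding` class and the `scaling_key` normalization (added because a dict-valued `rope_scaling` is not hashable as a tuple element) are illustrative assumptions, not part of the patch itself.

```python
from typing import Any, Dict, Optional, Tuple


class RotaryEmbedding:
    """Stand-in for vLLM's RotaryEmbedding; only records its config here."""

    def __init__(self, head_size: int, rotary_dim: int, max_position: int,
                 base: int, is_neox_style: bool) -> None:
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position = max_position
        self.base = base
        self.is_neox_style = is_neox_style


# Module-level cache keyed by the argument tuple, mirroring _ROPE_DICT above.
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


def get_rope(head_size: int,
             rotary_dim: int,
             max_position: int,
             base: int,
             is_neox_style: bool = True,
             rope_scaling: Optional[Dict[str, Any]] = None) -> RotaryEmbedding:
    # Assumption for this sketch: fold a dict-valued rope_scaling into a
    # hashable form so the whole key can be used as a dict key.
    scaling_key = (None if rope_scaling is None else
                   tuple(sorted(rope_scaling.items())))
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           scaling_key)
    if key in _ROPE_DICT:
        # Cache hit: reuse the previously constructed module.
        return _ROPE_DICT[key]
    rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                 is_neox_style)
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb


if __name__ == "__main__":
    a = get_rope(64, 64, 4096, 10000)
    b = get_rope(64, 64, 4096, 10000)
    assert a is b  # identical configs return the same cached instance
```

Usage-wise, every attention layer of a model typically calls `get_rope` with identical arguments, so with the cache only the first call pays the cost of building the position cache; subsequent calls return the existing object.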