From 63b2206ad01499921428ba50c85a18c92772f26c Mon Sep 17 00:00:00 2001
From: Jee Li
Date: Thu, 30 Nov 2023 15:06:27 +0800
Subject: [PATCH] Avoid multiple instantiations of the RoPE class (#1828)

---
 vllm/model_executor/layers/rotary_embedding.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 0bde4cef..9c109cc3 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -272,6 +272,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         return cache
 
 
+_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
+
+
 def get_rope(
     head_size: int,
     rotary_dim: int,
@@ -280,6 +283,10 @@ def get_rope(
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
+    key = (head_size, rotary_dim, max_position, base, is_neox_style,
+           rope_scaling)
+    if key in _ROPE_DICT:
+        return _ROPE_DICT[key]
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position,
                                      base, is_neox_style)
@@ -312,4 +319,5 @@ def get_rope(
                                                     **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+    _ROPE_DICT[key] = rotary_emb
     return rotary_emb
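The patch memoizes `get_rope` behind a module-level dictionary keyed by the constructor arguments, so layers that request identical RoPE parameters share one `RotaryEmbedding` instance (and its cos/sin cache) instead of each building their own. Below is a minimal, self-contained sketch of that memoization pattern, not vLLM's actual module: the `RotaryEmbedding` body here is a hypothetical stand-in, and because Python dicts are not hashable, the sketch folds `rope_scaling` into the cache key as a sorted tuple of its items.

```python
from typing import Any, Dict, Optional, Tuple


class RotaryEmbedding:
    """Hypothetical stand-in for vLLM's RotaryEmbedding layer."""

    def __init__(self, head_size: int, rotary_dim: int, max_position: int,
                 base: int, is_neox_style: bool) -> None:
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position = max_position
        self.base = base
        self.is_neox_style = is_neox_style


# Module-level cache: one RotaryEmbedding per unique argument combination.
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


def get_rope(head_size: int,
             rotary_dim: int,
             max_position: int,
             base: int,
             is_neox_style: bool = True,
             rope_scaling: Optional[Dict[str, Any]] = None) -> RotaryEmbedding:
    # Dicts are unhashable, so normalize rope_scaling into a hashable
    # form before using it as part of the cache key.
    scaling_key = (None if rope_scaling is None else tuple(
        sorted(rope_scaling.items())))
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           scaling_key)
    if key in _ROPE_DICT:
        # Reuse the existing instance instead of rebuilding it.
        return _ROPE_DICT[key]
    rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                 is_neox_style)
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb


if __name__ == "__main__":
    a = get_rope(128, 128, 4096, 10000)
    b = get_rope(128, 128, 4096, 10000)
    assert a is b  # identical arguments -> same cached instance
```

One design note on this pattern: the cache is keyed on every argument that affects construction, so two call sites get the same object only when they would have built byte-for-byte identical embeddings; any differing parameter still yields a distinct instance.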