From d27f4bae393214b4e7715fc3cb5754d4bf801bce Mon Sep 17 00:00:00 2001 From: Roy Date: Fri, 1 Dec 2023 00:29:28 +0800 Subject: [PATCH] Fix rope cache key error (#1867) --- vllm/model_executor/layers/rotary_embedding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 9c109cc3..b3a4d38b 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -284,9 +284,10 @@ def get_rope( rope_scaling: Optional[Dict[str, Any]] = None, ) -> RotaryEmbedding: key = (head_size, rotary_dim, max_position, base, is_neox_style, - rope_scaling) + tuple(rope_scaling.items()) if rope_scaling is not None else None) if key in _ROPE_DICT: return _ROPE_DICT[key] + if rope_scaling is None: rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style)