Fix rope cache key error (#1867)

This commit is contained in:
Roy 2023-12-01 00:29:28 +08:00 committed by GitHub
parent 8d8c2f6ffe
commit d27f4bae39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -284,9 +284,10 @@ def get_rope(
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding: ) -> RotaryEmbedding:
key = (head_size, rotary_dim, max_position, base, is_neox_style, key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling) tuple(rope_scaling.items()) if rope_scaling is not None else None)
if key in _ROPE_DICT: if key in _ROPE_DICT:
return _ROPE_DICT[key] return _ROPE_DICT[key]
if rope_scaling is None: if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style) is_neox_style)