Fix rope cache key error (#1867)
This commit is contained in:
parent
8d8c2f6ffe
commit
d27f4bae39
@ -284,9 +284,10 @@ def get_rope(
|
|||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
) -> RotaryEmbedding:
|
) -> RotaryEmbedding:
|
||||||
key = (head_size, rotary_dim, max_position, base, is_neox_style,
|
key = (head_size, rotary_dim, max_position, base, is_neox_style,
|
||||||
rope_scaling)
|
tuple(rope_scaling.items()) if rope_scaling is not None else None)
|
||||||
if key in _ROPE_DICT:
|
if key in _ROPE_DICT:
|
||||||
return _ROPE_DICT[key]
|
return _ROPE_DICT[key]
|
||||||
|
|
||||||
if rope_scaling is None:
|
if rope_scaling is None:
|
||||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||||
is_neox_style)
|
is_neox_style)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user