[Model] Rename Phi3 rope scaling type (#5595)

This commit is contained in:
Amit Garg 2024-06-17 09:04:14 -07:00 committed by GitHub
parent e2b85cf86a
commit 9333fb8eb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 8 deletions

View File

@ -1287,7 +1287,10 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None and rope_scaling["type"] != "su":
# The correct one should be "longrope", kept "su" here
# to be backward compatible
if rope_scaling is not None and rope_scaling["type"] != "su" \
and rope_scaling["type"] != "longrope":
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.

View File

@ -467,7 +467,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
return cache
class Phi3SuScaledRotaryEmbedding(nn.Module):
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
@ -491,11 +491,12 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
if rotary_dim != head_size:
raise ValueError(
f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
head_size ({rotary_dim}!={head_size}).")
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
rotary_dim != head_size ({rotary_dim}!={head_size}).")
if is_neox_style is False:
raise ValueError(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)
self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
@ -608,7 +609,9 @@ def get_rope(
is_neox_style, dtype)
else:
scaling_type = rope_scaling["type"]
if scaling_type != "su":
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type != "su" and scaling_type != "longrope":
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
@ -633,7 +636,9 @@ def get_rope(
base, is_neox_style,
scaling_factor, dtype,
**extra_kwargs)
elif scaling_type == "su":
# The correct one should be "longrope" but keep "su" here
# for backward compatible
elif scaling_type == "su" or scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
@ -643,7 +648,7 @@ def get_rope(
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3SuScaledRotaryEmbedding(
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs)