[Model] Rename Phi3 rope scaling type (#5595)
This commit is contained in:
parent
e2b85cf86a
commit
9333fb8eb9
@ -1287,7 +1287,10 @@ def _get_and_verify_max_len(
|
||||
derived_max_model_len = default_max_len
|
||||
|
||||
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
||||
if rope_scaling is not None and rope_scaling["type"] != "su":
|
||||
# The correct one should be "longrope", kept "su" here
|
||||
# to be backward compatible
|
||||
if rope_scaling is not None and rope_scaling["type"] != "su" \
|
||||
and rope_scaling["type"] != "longrope":
|
||||
if disable_sliding_window:
|
||||
# TODO(robertgshaw): Find a model that supports rope_scaling
|
||||
# with sliding window to see if this case should be allowed.
|
||||
|
||||
@ -467,7 +467,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
||||
return cache
|
||||
|
||||
|
||||
class Phi3SuScaledRotaryEmbedding(nn.Module):
|
||||
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
"""Phi3 family of models scaled rotary embedding.
|
||||
|
||||
Based on the original RotaryEmbedding implementation.
|
||||
@ -491,11 +491,12 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
|
||||
|
||||
if rotary_dim != head_size:
|
||||
raise ValueError(
|
||||
f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
|
||||
head_size ({rotary_dim}!={head_size}).")
|
||||
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
|
||||
rotary_dim != head_size ({rotary_dim}!={head_size}).")
|
||||
if is_neox_style is False:
|
||||
raise ValueError(
|
||||
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
|
||||
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
|
||||
)
|
||||
|
||||
self.head_size = head_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
@ -608,7 +609,9 @@ def get_rope(
|
||||
is_neox_style, dtype)
|
||||
else:
|
||||
scaling_type = rope_scaling["type"]
|
||||
if scaling_type != "su":
|
||||
# The correct one should be "longrope" but keep "su" here
|
||||
# for backward compatible
|
||||
if scaling_type != "su" and scaling_type != "longrope":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
|
||||
@ -633,7 +636,9 @@ def get_rope(
|
||||
base, is_neox_style,
|
||||
scaling_factor, dtype,
|
||||
**extra_kwargs)
|
||||
elif scaling_type == "su":
|
||||
# The correct one should be "longrope" but keep "su" here
|
||||
# for backward compatible
|
||||
elif scaling_type == "su" or scaling_type == "longrope":
|
||||
short_factor = rope_scaling["short_factor"]
|
||||
long_factor = rope_scaling["long_factor"]
|
||||
original_max_position = rope_scaling[
|
||||
@ -643,7 +648,7 @@ def get_rope(
|
||||
for k, v in rope_scaling.items()
|
||||
if k in ("short_mscale", "long_mscale")
|
||||
}
|
||||
rotary_emb = Phi3SuScaledRotaryEmbedding(
|
||||
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, original_max_position,
|
||||
base, is_neox_style, dtype, short_factor, long_factor,
|
||||
**extra_kwargs)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user