[Model] Rename Phi3 rope scaling type (#5595)
This commit is contained in:
parent
e2b85cf86a
commit
9333fb8eb9
@ -1287,7 +1287,10 @@ def _get_and_verify_max_len(
|
|||||||
derived_max_model_len = default_max_len
|
derived_max_model_len = default_max_len
|
||||||
|
|
||||||
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
||||||
if rope_scaling is not None and rope_scaling["type"] != "su":
|
# The correct one should be "longrope", kept "su" here
|
||||||
|
# to be backward compatible
|
||||||
|
if rope_scaling is not None and rope_scaling["type"] != "su" \
|
||||||
|
and rope_scaling["type"] != "longrope":
|
||||||
if disable_sliding_window:
|
if disable_sliding_window:
|
||||||
# TODO(robertgshaw): Find a model that supports rope_scaling
|
# TODO(robertgshaw): Find a model that supports rope_scaling
|
||||||
# with sliding window to see if this case should be allowed.
|
# with sliding window to see if this case should be allowed.
|
||||||
|
|||||||
@ -467,7 +467,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
return cache
|
return cache
|
||||||
|
|
||||||
|
|
||||||
class Phi3SuScaledRotaryEmbedding(nn.Module):
|
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||||
"""Phi3 family of models scaled rotary embedding.
|
"""Phi3 family of models scaled rotary embedding.
|
||||||
|
|
||||||
Based on the original RotaryEmbedding implementation.
|
Based on the original RotaryEmbedding implementation.
|
||||||
@ -491,11 +491,12 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
|
|||||||
|
|
||||||
if rotary_dim != head_size:
|
if rotary_dim != head_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
|
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
|
||||||
head_size ({rotary_dim}!={head_size}).")
|
rotary_dim != head_size ({rotary_dim}!={head_size}).")
|
||||||
if is_neox_style is False:
|
if is_neox_style is False:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
|
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
|
||||||
|
)
|
||||||
|
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
@ -608,7 +609,9 @@ def get_rope(
|
|||||||
is_neox_style, dtype)
|
is_neox_style, dtype)
|
||||||
else:
|
else:
|
||||||
scaling_type = rope_scaling["type"]
|
scaling_type = rope_scaling["type"]
|
||||||
if scaling_type != "su":
|
# The correct one should be "longrope" but keep "su" here
|
||||||
|
# for backward compatibility
|
||||||
|
if scaling_type != "su" and scaling_type != "longrope":
|
||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
if scaling_type == "linear":
|
if scaling_type == "linear":
|
||||||
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
|
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
|
||||||
@ -633,7 +636,9 @@ def get_rope(
|
|||||||
base, is_neox_style,
|
base, is_neox_style,
|
||||||
scaling_factor, dtype,
|
scaling_factor, dtype,
|
||||||
**extra_kwargs)
|
**extra_kwargs)
|
||||||
elif scaling_type == "su":
|
# The correct one should be "longrope" but keep "su" here
|
||||||
|
# for backward compatibility
|
||||||
|
elif scaling_type == "su" or scaling_type == "longrope":
|
||||||
short_factor = rope_scaling["short_factor"]
|
short_factor = rope_scaling["short_factor"]
|
||||||
long_factor = rope_scaling["long_factor"]
|
long_factor = rope_scaling["long_factor"]
|
||||||
original_max_position = rope_scaling[
|
original_max_position = rope_scaling[
|
||||||
@ -643,7 +648,7 @@ def get_rope(
|
|||||||
for k, v in rope_scaling.items()
|
for k, v in rope_scaling.items()
|
||||||
if k in ("short_mscale", "long_mscale")
|
if k in ("short_mscale", "long_mscale")
|
||||||
}
|
}
|
||||||
rotary_emb = Phi3SuScaledRotaryEmbedding(
|
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
|
||||||
head_size, rotary_dim, max_position, original_max_position,
|
head_size, rotary_dim, max_position, original_max_position,
|
||||||
base, is_neox_style, dtype, short_factor, long_factor,
|
base, is_neox_style, dtype, short_factor, long_factor,
|
||||||
**extra_kwargs)
|
**extra_kwargs)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user