Add support for a rope extension method (#6553)
This commit is contained in:
parent
1689219ebf
commit
c5df56f88b
@ -151,6 +151,15 @@ class ModelConfig:
|
|||||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||||
|
|
||||||
|
if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
|
||||||
|
and getattr(self.hf_config, "rope_scaling", None) is None):
|
||||||
|
# Note(simon): this is a special case for a model that doesn't
|
||||||
|
# supply rope_scaling. We should remove this once the model is
|
||||||
|
# updated.
|
||||||
|
self.hf_config.update({"rope_scaling": {
|
||||||
|
"type": "extended",
|
||||||
|
}})
|
||||||
|
|
||||||
if (not self.disable_sliding_window
|
if (not self.disable_sliding_window
|
||||||
and self.hf_text_config.model_type == "gemma2"
|
and self.hf_text_config.model_type == "gemma2"
|
||||||
and self.hf_text_config.sliding_window is not None):
|
and self.hf_text_config.sliding_window is not None):
|
||||||
@ -1442,8 +1451,9 @@ def _get_and_verify_max_len(
|
|||||||
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
||||||
# The correct one should be "longrope", kept "su" here
|
# The correct one should be "longrope", kept "su" here
|
||||||
# to be backward compatible
|
# to be backward compatible
|
||||||
if rope_scaling is not None and rope_scaling["type"] != "su" \
|
if rope_scaling is not None and rope_scaling["type"] not in {
|
||||||
and rope_scaling["type"] != "longrope":
|
"su", "longrope", "extended"
|
||||||
|
}:
|
||||||
if disable_sliding_window:
|
if disable_sliding_window:
|
||||||
# TODO(robertgshaw): Find a model that supports rope_scaling
|
# TODO(robertgshaw): Find a model that supports rope_scaling
|
||||||
# with sliding window to see if this case should be allowed.
|
# with sliding window to see if this case should be allowed.
|
||||||
|
|||||||
@ -733,6 +733,36 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
|
|||||||
return inv_freq
|
return inv_freq
|
||||||
|
|
||||||
|
|
||||||
|
class ExtendedRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose inverse frequencies are rescaled per band.

    The base class computes the inverse frequencies; this subclass
    post-processes them with :meth:`apply_scaling`, which leaves
    short-wavelength components untouched, divides long-wavelength
    components by a fixed factor, and interpolates in between.
    """

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Return the base class's inverse frequencies after band scaling."""
        inv_freqs = super()._compute_inv_freq(base)
        return self.apply_scaling(inv_freqs)

    def apply_scaling(self, freqs: torch.Tensor) -> torch.Tensor:
        """Rescale ``freqs`` element-wise according to their wavelength.

        For every frequency ``f`` with wavelength ``2 * pi / f``:

        * wavelength < ``old_context_len / high_freq_factor``: ``f`` is kept;
        * wavelength > ``old_context_len / low_freq_factor``: ``f`` is
          divided by ``scale_factor``;
        * otherwise: a linear blend of the two, weighted by ``smooth``.

        Args:
            freqs: Tensor of (inverse) frequencies to rescale.

        Returns:
            A new tensor with the same shape, dtype and device as ``freqs``.
        """
        # Scaling parameters are fixed for this rope extension method.
        scale_factor = 8
        low_freq_factor = 1
        high_freq_factor = 4
        old_context_len = 8192

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        # Evaluate all three candidate results for the whole tensor at once
        # and select per element. This avoids a Python-level loop over 0-d
        # tensors (and the per-element host sync it implies on accelerators).
        wavelen = 2 * math.pi / freqs

        # Interpolation weight for the middle band. The denominator is a
        # non-zero constant because the two factors above differ.
        smooth = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor)

        scaled = freqs / scale_factor
        blended = (1 - smooth) * freqs / scale_factor + smooth * freqs

        return torch.where(
            wavelen < high_freq_wavelen,
            freqs,
            torch.where(wavelen > low_freq_wavelen, scaled, blended),
        )
|
||||||
|
|
||||||
|
|
||||||
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
||||||
|
|
||||||
|
|
||||||
@ -767,9 +797,13 @@ def get_rope(
|
|||||||
scaling_type = rope_scaling["type"]
|
scaling_type = rope_scaling["type"]
|
||||||
# The correct one should be "longrope" but keep "su" here
|
# The correct one should be "longrope" but keep "su" here
|
||||||
# for backward compatible
|
# for backward compatible
|
||||||
if scaling_type != "su" and scaling_type != "longrope":
|
if scaling_type not in {"su", "longrope", "extended"}:
|
||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
if scaling_type == "linear":
|
if scaling_type == "extended":
|
||||||
|
rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
|
||||||
|
max_position, base,
|
||||||
|
is_neox_style, dtype)
|
||||||
|
elif scaling_type == "linear":
|
||||||
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
|
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
|
||||||
max_position, base,
|
max_position, base,
|
||||||
is_neox_style,
|
is_neox_style,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user