[Bugfix] Fix Phi-3 Long RoPE scaling implementation (#5628)
parent 6820724e51
commit 59a1eb59c9
@@ -507,8 +507,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         dtype: torch.dtype,
         short_factor: List[float],
         long_factor: List[float],
-        short_mscale: float = 1.1,
-        long_mscale: float = 1.225,
+        short_mscale: float = 1.0,
+        long_mscale: float = 1.0,
     ):
         super().__init__()

@@ -530,6 +530,16 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         self.short_mscale = short_mscale
         self.long_mscale = long_mscale

+        scale = (self.max_position_embeddings /
+                 self.original_max_position_embeddings)
+
+        if scale <= 1.0:
+            self.scaling_factor = 1.0
+        else:
+            self.scaling_factor = math.sqrt(
+                1 + math.log(scale) /
+                math.log(self.original_max_position_embeddings))
+
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
         short_cache = short_cache.to(dtype)
@@ -565,8 +575,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         inv_freq = self._compute_inv_freq(rescale_factors)
         t = torch.arange(max_position_embeddings, dtype=torch.float)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
-        cos = freqs.cos() * mscale
-        sin = freqs.sin() * mscale
+        cos = freqs.cos() * mscale * self.scaling_factor
+        sin = freqs.sin() * mscale * self.scaling_factor
         cache = torch.cat((cos, sin), dim=-1)
         return cache
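For reference, the new scaling factor can be reproduced in isolation. The sketch below mirrors the logic added in the second hunk; the standalone function name and the example config values (4096 original context extended to 131072, as in a Phi-3-mini-128k-style setup) are illustrative assumptions, not part of the patch itself.

import math

def phi3_longrope_scaling_factor(max_position_embeddings: int,
                                 original_max_position_embeddings: int) -> float:
    """Attention scaling applied to the cos/sin cache: no scaling when the
    context is not extended, otherwise sqrt(1 + log(extension ratio) /
    log(original context length))."""
    scale = max_position_embeddings / original_max_position_embeddings
    if scale <= 1.0:
        return 1.0
    return math.sqrt(
        1 + math.log(scale) / math.log(original_max_position_embeddings))

# Example (illustrative numbers): extending 4096 -> 131072 gives ~1.1902,
# which replaces the previously hard-coded short_mscale=1.1 / long_mscale=1.225.
print(phi3_longrope_scaling_factor(131072, 4096))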