From 59a1eb59c9cb383e5ea36d7253f81ff2ea7766cc Mon Sep 17 00:00:00 2001
From: Shukant Pal
Date: Tue, 18 Jun 2024 18:46:38 -0700
Subject: [PATCH] [Bugfix] Fix Phi-3 Long RoPE scaling implementation (#5628)

---
 vllm/model_executor/layers/rotary_embedding.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 9c0a74cd..a0b19046 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -507,8 +507,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         dtype: torch.dtype,
         short_factor: List[float],
         long_factor: List[float],
-        short_mscale: float = 1.1,
-        long_mscale: float = 1.225,
+        short_mscale: float = 1.0,
+        long_mscale: float = 1.0,
     ):
         super().__init__()
 
@@ -530,6 +530,16 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         self.short_mscale = short_mscale
         self.long_mscale = long_mscale
 
+        scale = (self.max_position_embeddings /
+                 self.original_max_position_embeddings)
+
+        if scale <= 1.0:
+            self.scaling_factor = 1.0
+        else:
+            self.scaling_factor = math.sqrt(
+                1 + math.log(scale) /
+                math.log(self.original_max_position_embeddings))
+
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
         short_cache = short_cache.to(dtype)
@@ -565,8 +575,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         inv_freq = self._compute_inv_freq(rescale_factors)
         t = torch.arange(max_position_embeddings, dtype=torch.float)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
-        cos = freqs.cos() * mscale
-        sin = freqs.sin() * mscale
+        cos = freqs.cos() * mscale * self.scaling_factor
+        sin = freqs.sin() * mscale * self.scaling_factor
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
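Note on the math: the patch drops the hard-coded mscale defaults (1.1 / 1.225) in favor of 1.0, and instead derives a single attention scaling factor sqrt(1 + ln(s) / ln(L)), where s = max_position_embeddings / original_max_position_embeddings and L = original_max_position_embeddings; for s <= 1 no extra scaling is applied. The sketch below reproduces that computation outside vLLM so it can be checked in isolation. The helper name long_rope_scaling_factor and the toy cos/sin cache at the end are illustrative only, not part of the vLLM API.

    import math

    import torch

    def long_rope_scaling_factor(max_position_embeddings: int,
                                 original_max_position_embeddings: int) -> float:
        # Ratio of the extended context length to the pretrained context length.
        scale = max_position_embeddings / original_max_position_embeddings
        if scale <= 1.0:
            # Running within the original window: no extra attention scaling.
            return 1.0
        # LongRoPE-style attention scaling: sqrt(1 + ln(s) / ln(L)).
        return math.sqrt(1 + math.log(scale) /
                         math.log(original_max_position_embeddings))

    # Phi-3-mini-128k-style values: a 4k pretrained window extended to 128k.
    factor = long_rope_scaling_factor(131072, 4096)
    print(f"{factor:.4f}")  # ~1.1902

    # Mirrors how the patch applies the factor to the cos/sin cache
    # (toy dimensions; vLLM derives inv_freq from the rescaled rope factors).
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 32, 2).float() / 32))
    t = torch.arange(16, dtype=torch.float)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos() * factor
    sin = freqs.sin() * factor
    cache = torch.cat((cos, sin), dim=-1)  # shape (16, 32)

Deriving the factor from the context-length ratio, rather than using fixed multipliers, keeps the short and long caches scaled consistently with the formula used by the reference Phi-3 LongRoPE implementation, which is what makes this a bugfix rather than a tuning change.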