diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index e6ee2b96..0562b71a 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -93,11 +93,6 @@ class RotaryEmbedding(CustomOp):
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
-        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
-        # However, we use `torch.arange(..., dtype=torch.float)` instead to
-        # avoid numerical issues with large base values (e.g., 10000000).
-        # This may cause a slight numerical difference between the HF
-        # implementation and ours.
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
         # create the cache on GPU for faster initialization. This may cause
@@ -724,16 +719,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         return query, key
 
 
-class GemmaRotaryEmbedding(RotaryEmbedding):
-
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
-            self.rotary_dim))
-        return inv_freq
-
-
 class Llama3RotaryEmbedding(RotaryEmbedding):
 
     def __init__(
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 14d15788..7a9ee3d9 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -148,14 +148,12 @@ class GemmaAttention(nn.Module):
             quant_config=quant_config,
         )
 
-        # TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
-            max_position_embeddings=max_position_embeddings,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index aa9cff02..ff547c2c 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -130,14 +130,12 @@ class Gemma2Attention(nn.Module):
             bias=config.attention_bias,
             quant_config=quant_config,
         )
-        # TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = get_rope(
             self.head_dim,
-            self.head_dim,
-            max_position_embeddings,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
         # FIXME(woosuk): While Gemma 2 uses sliding window attention for every
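
Why the dedicated `GemmaRotaryEmbedding` can go away: its `_compute_inv_freq` differs from the base `RotaryEmbedding` only in building the exponents via an int64 `arange` cast to float (following the HF Gemma reference), while the base class uses a float `arange` directly, as the deleted NOTE describes. The exponents are small even integers that float32 represents exactly, so both paths produce identical inverse frequencies. The snippet below is a standalone sanity check of that equivalence, not part of the patch; `rotary_dim` and `base` are illustrative values rather than the actual Gemma configuration.

```python
import torch

rotary_dim = 256   # illustrative Gemma-style head size
base = 10000.0     # illustrative rope_theta

# Exponents built the way the removed GemmaRotaryEmbedding did it:
# int64 arange, then cast to float (HF Gemma reference).
inv_freq_gemma = 1.0 / (base**(
    torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim))

# Exponents built the way the base RotaryEmbedding does it:
# float arange directly (see the deleted NOTE above).
inv_freq_base = 1.0 / (base**(
    torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))

# Small integer exponents are exact in float32, so the two
# computations agree bit-for-bit.
assert torch.equal(inv_freq_gemma, inv_freq_base)
```

Since neither Gemma call site passes rope scaling, `get_rope` should resolve to the plain `RotaryEmbedding` here, leaving the models' behaviour unchanged.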