From 82c540bebf599e3208683bebf87db56175145ebd Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 27 Mar 2024 09:37:18 -0700
Subject: [PATCH] [Bugfix] More faithful implementation of Gemma (#3653)

---
 vllm/model_executor/models/gemma.py | 46 +++++++++++++++++++++++++++--
 1 file changed, 43 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index a5432a03..08609532 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Gemma model compatible with HuggingFace weights."""
+from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import torch
@@ -22,6 +23,7 @@ from transformers import GemmaConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import LoRAConfig
+from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -40,6 +42,34 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
 
+logger = init_logger(__name__)
+
+
+@lru_cache(maxsize=None)
+def _get_gemma_act_fn(
+    hidden_act: Optional[str],
+    hidden_activation: Optional[str],
+) -> nn.Module:
+    if hidden_activation is None:
+        if hidden_act is not None:
+            logger.warning(
+                "Gemma's activation function was incorrectly set to exact GeLU "
+                "in the config JSON file when it was initially released. "
+                "Changing the activation function to approximate GeLU "
+                "(`gelu_pytorch_tanh`). If you want to use the legacy "
+                f"`{hidden_act}`, edit the config JSON to set "
+                f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
+                "See https://github.com/huggingface/transformers/pull/29402 "
+                "for more details.")
+        return GeluAndMul(approximate="tanh")
+    elif hidden_activation == "gelu_pytorch_tanh":
+        return GeluAndMul(approximate="tanh")
+    elif hidden_activation == "gelu":
+        return GeluAndMul(approximate="none")
+    else:
+        raise ValueError(f"Activation function {hidden_activation} is not "
+                         "supported for Gemma models.")
+
 
 class GemmaMLP(nn.Module):
 
@@ -47,6 +77,8 @@ class GemmaMLP(nn.Module):
         self,
         hidden_size: int,
         intermediate_size: int,
+        hidden_act: Optional[str] = None,
+        hidden_activation: Optional[str] = None,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
@@ -58,7 +90,7 @@
                                            hidden_size,
                                            bias=False,
                                            linear_method=linear_method)
-        self.act_fn = GeluAndMul()
+        self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
 
     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
@@ -162,6 +194,8 @@ class GemmaDecoderLayer(nn.Module):
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            hidden_activation=getattr(config, "hidden_activation", None),
             linear_method=linear_method,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -218,6 +252,13 @@ class GemmaModel(nn.Module):
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        # Normalize the embedding by sqrt(hidden_size)
+        # The normalizer's data type should be downcasted to the model's
+        # data type such as bfloat16, not float32.
+        # See https://github.com/huggingface/transformers/pull/29402
+        normalizer = self.config.hidden_size**0.5
+        self.register_buffer("normalizer", torch.tensor(normalizer))
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -226,8 +267,7 @@
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
-        # Normalize the embedding by sqrt(hidden_size)
-        hidden_states *= self.config.hidden_size**0.5
+        hidden_states *= self.normalizer
 
         residual = None
         for i in range(len(self.layers)):
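
Illustrative sketch (not part of the diff above): the numerical core of this change is that the sqrt(hidden_size) embedding normalizer now lives in a registered buffer, so a later model.to(torch.bfloat16) downcasts it along with the weights instead of leaving it a full-precision Python float applied on every forward pass, matching the HF fix linked in the comment. The snippet below shows the rounding effect, assuming PyTorch and using 3072 (Gemma-7B's hidden size) as an example value:

import torch

hidden_size = 3072  # example value: Gemma-7B's hidden size

# The patch stores sqrt(hidden_size) via register_buffer, so it follows the
# model's dtype conversion instead of staying a full-precision float.
normalizer = torch.tensor(hidden_size**0.5)
print(normalizer.item())                     # ~55.4256 (float32)
print(normalizer.to(torch.bfloat16).item())  # 55.5 after downcasting

Because the buffer is converted with the rest of the model, `hidden_states *= self.normalizer` now scales bfloat16 activations by the bfloat16-rounded constant, which is what makes the implementation more faithful to the reference behavior.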