[Bugfix] More faithful implementation of Gemma (#3653)

Woosuk Kwon 2024-03-27 09:37:18 -07:00 committed by GitHub
parent 8f44facddd
commit 82c540bebf


@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import List, Optional, Tuple
import torch
@@ -22,6 +23,7 @@ from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -40,6 +42,34 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
logger = init_logger(__name__)
@lru_cache(maxsize=None)
def _get_gemma_act_fn(
hidden_act: Optional[str],
hidden_activation: Optional[str],
) -> nn.Module:
if hidden_activation is None:
if hidden_act is not None:
logger.warning(
"Gemma's activation function was incorrectly set to exact GeLU "
"in the config JSON file when it was initially released. "
"Changing the activation function to approximate GeLU "
"(`gelu_pytorch_tanh`). If you want to use the legacy "
f"`{hidden_act}`, edit the config JSON to set "
f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
"See https://github.com/huggingface/transformers/pull/29402 "
"for more details.")
return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu_pytorch_tanh":
return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu":
return GeluAndMul(approximate="none")
else:
raise ValueError(f"Activation function {hidden_activation} is not "
"supported for Gemma models.")
class GemmaMLP(nn.Module):
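A minimal usage sketch of the `_get_gemma_act_fn` helper added above (not part of the diff; it assumes vLLM is importable and that the calls run in the same module as the helper). Because the helper is wrapped in functools.lru_cache, each (hidden_act, hidden_activation) pair is resolved only once, so the legacy-config warning is logged a single time rather than once per decoder layer, and the layers share one GeluAndMul instance, which is safe since the module holds no parameters.

# Old-style config: only hidden_act is set -> warn once, use approximate GeLU.
act = _get_gemma_act_fn(hidden_act="gelu", hidden_activation=None)

# New-style config: hidden_activation is authoritative.
act_tanh = _get_gemma_act_fn(hidden_act=None, hidden_activation="gelu_pytorch_tanh")
act_exact = _get_gemma_act_fn(hidden_act=None, hidden_activation="gelu")

# Any other value is rejected.
try:
    _get_gemma_act_fn(hidden_act=None, hidden_activation="silu")
except ValueError:
    pass  # unsupported activation for Gemma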
@@ -47,6 +77,8 @@ class GemmaMLP(nn.Module):
self,
hidden_size: int,
intermediate_size: int,
hidden_act: Optional[str] = None,
hidden_activation: Optional[str] = None,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
@@ -58,7 +90,7 @@
hidden_size,
bias=False,
linear_method=linear_method)
- self.act_fn = GeluAndMul()
+ self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
@@ -162,6 +194,8 @@ class GemmaDecoderLayer(nn.Module):
self.mlp = GemmaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
hidden_activation=getattr(config, "hidden_activation", None),
linear_method=linear_method,
)
self.input_layernorm = RMSNorm(config.hidden_size,
@@ -218,6 +252,13 @@ class GemmaModel(nn.Module):
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size)
# The normalizer's data type should be downcast to the model's
# data type, such as bfloat16, not float32.
# See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
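The comments above require the normalizer to end up in the model's dtype rather than staying float32. Registering it as a buffer is what provides that: PyTorch converts floating-point buffers together with parameters when a module is cast to a lower-precision dtype. A standalone sketch of that behavior (the Toy module and the hidden size are illustrative, not vLLM code):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self, hidden_size: int = 3072):  # illustrative size
        super().__init__()
        # Same pattern as the diff: keep sqrt(hidden_size) as a buffer so it
        # follows the module's dtype and device instead of staying float32.
        self.register_buffer("normalizer", torch.tensor(hidden_size**0.5))

toy = Toy().to(torch.bfloat16)
print(toy.normalizer.dtype)  # torch.bfloat16: the buffer was downcast with the module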
def forward(
self,
input_ids: torch.Tensor,
@@ -226,8 +267,7 @@ class GemmaModel(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
- # Normalize the embedding by sqrt(hidden_size)
- hidden_states *= self.config.hidden_size**0.5
+ hidden_states *= self.normalizer
residual = None
for i in range(len(self.layers)):
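For the forward-pass change above: the removed lines scaled the embedding by the Python float hidden_size**0.5, which applies the factor at higher precision, while the new code multiplies by the pre-downcast buffer, matching the HuggingFace reference linked in the comment. A standalone sketch of why the two can differ numerically (sizes and data are illustrative):

import torch

hidden_size = 3072  # illustrative; real checkpoints define this in the config
x = torch.randn(2, hidden_size, dtype=torch.bfloat16)

old = x * hidden_size**0.5                                      # scale applied at higher precision
new = x * torch.tensor(hidden_size**0.5, dtype=torch.bfloat16)  # scale rounded to bfloat16 first

print(torch.equal(old, new))  # typically False: results differ in the last bits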