[Bugfix] More faithful implementation of Gemma (#3653)
This commit is contained in:
parent 8f44facddd
commit 82c540bebf
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Gemma model compatible with HuggingFace weights."""
+from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import torch
@@ -22,6 +23,7 @@ from transformers import GemmaConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import LoRAConfig
+from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -40,6 +42,34 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
 
+logger = init_logger(__name__)
+
+
+@lru_cache(maxsize=None)
+def _get_gemma_act_fn(
+    hidden_act: Optional[str],
+    hidden_activation: Optional[str],
+) -> nn.Module:
+    if hidden_activation is None:
+        if hidden_act is not None:
+            logger.warning(
+                "Gemma's activation function was incorrectly set to exact GeLU "
+                "in the config JSON file when it was initially released. "
+                "Changing the activation function to approximate GeLU "
+                "(`gelu_pytorch_tanh`). If you want to use the legacy "
+                f"`{hidden_act}`, edit the config JSON to set "
+                f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
+                "See https://github.com/huggingface/transformers/pull/29402 "
+                "for more details.")
+        return GeluAndMul(approximate="tanh")
+    elif hidden_activation == "gelu_pytorch_tanh":
+        return GeluAndMul(approximate="tanh")
+    elif hidden_activation == "gelu":
+        return GeluAndMul(approximate="none")
+    else:
+        raise ValueError(f"Activation function {hidden_act} is not "
+                         "supported for Gemma models.")
+
+
 class GemmaMLP(nn.Module):
 
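For intuition, here is a minimal PyTorch sketch of what the selected activation computes. It is an illustrative re-implementation using plain torch.nn.functional calls, not vLLM's fused GeluAndMul kernel: GELU is applied to the gate half of the concatenated gate/up projection and multiplied by the up half, and the "tanh" and exact ("none") variants produce slightly different values.

import torch
import torch.nn.functional as F

def gelu_and_mul(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    # Split the concatenated [gate, up] projection, apply GELU to the gate
    # half, and scale the up half elementwise by the result.
    gate, up = x.chunk(2, dim=-1)
    return F.gelu(gate, approximate=approximate) * up

x = torch.randn(4, 2 * 8)            # toy [gate, up] activations
tanh_out = gelu_and_mul(x, "tanh")   # "gelu_pytorch_tanh" path selected above
exact_out = gelu_and_mul(x, "none")  # legacy exact-GeLU path
print((tanh_out - exact_out).abs().max())  # small but nonzero difference

The difference is tiny but nonzero, which is why the helper defaults to the tanh approximation rather than silently keeping the exact GeLU from the original config JSON.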
@@ -47,6 +77,8 @@ class GemmaMLP(nn.Module):
         self,
         hidden_size: int,
         intermediate_size: int,
+        hidden_act: Optional[str] = None,
+        hidden_activation: Optional[str] = None,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
@@ -58,7 +90,7 @@ class GemmaMLP(nn.Module):
                                            hidden_size,
                                            bias=False,
                                            linear_method=linear_method)
-        self.act_fn = GeluAndMul()
+        self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
 
     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
@@ -162,6 +194,8 @@ class GemmaDecoderLayer(nn.Module):
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            hidden_activation=getattr(config, "hidden_activation", None),
             linear_method=linear_method,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
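As a rough illustration of the getattr fallback above, the sketch below uses SimpleNamespace stand-ins for two generations of Gemma config (not real transformers.GemmaConfig objects): configs released before the hidden_activation field existed resolve to None and take the warning path in _get_gemma_act_fn, while fixed configs select the activation directly.

from types import SimpleNamespace

# Illustrative config stand-ins, not actual transformers.GemmaConfig instances.
old_config = SimpleNamespace(hidden_act="gelu")  # original release: no hidden_activation
new_config = SimpleNamespace(hidden_act="gelu",
                             hidden_activation="gelu_pytorch_tanh")

for cfg in (old_config, new_config):
    hidden_activation = getattr(cfg, "hidden_activation", None)
    print(cfg.hidden_act, hidden_activation)
# old_config -> hidden_activation is None, so _get_gemma_act_fn falls back to
# approximate GeLU and logs the warning; new_config selects it explicitly.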
@@ -218,6 +252,13 @@ class GemmaModel(nn.Module):
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        # Normalize the embedding by sqrt(hidden_size)
+        # The normalizer's data type should be downcasted to the model's
+        # data type such as bfloat16, not float32.
+        # See https://github.com/huggingface/transformers/pull/29402
+        normalizer = self.config.hidden_size**0.5
+        self.register_buffer("normalizer", torch.tensor(normalizer))
+
     def forward(
         self,
         input_ids: torch.Tensor,
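A small standalone sketch of why registering the scale as a buffer matters (the hidden_size value and tensors are illustrative, not from a real checkpoint): buffers are cast by Module.to(dtype) together with the parameters, so the sqrt(hidden_size) factor is itself rounded to bfloat16 before the multiply, whereas the old code multiplied by a full-precision Python float.

import torch

hidden_size = 3072                           # illustrative; the config defines the real value
normalizer = torch.tensor(hidden_size**0.5)

# nn.Module.to(torch.bfloat16) casts registered buffers along with parameters,
# so the scale used at runtime is the bfloat16-rounded value.
bf16_scale = normalizer.to(torch.bfloat16)

emb = torch.randn(2, hidden_size, dtype=torch.bfloat16)
buffered = emb * bf16_scale          # new path: scale rounded to bf16 first
legacy = emb * hidden_size**0.5      # old path: exact Python float scale
print((buffered.float() - legacy.float()).abs().max())  # tiny, but nonzero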
@@ -226,8 +267,7 @@ class GemmaModel(nn.Module):
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
-        # Normalize the embedding by sqrt(hidden_size)
-        hidden_states *= self.config.hidden_size**0.5
+        hidden_states *= self.normalizer
 
         residual = None
         for i in range(len(self.layers)):