Use Llama RMSNorm custom op for Gemma (#2974)
parent 344020c926
commit 95529e3253
@@ -22,6 +22,7 @@ from transformers import GemmaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -40,21 +41,6 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class GemmaRMSNorm(nn.Module):
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.zeros(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
-
-
 class GemmaMLP(nn.Module):
 
     def __init__(
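The deleted GemmaRMSNorm is replaced by the shared RMSNorm layer from vllm.model_executor.layers.layernorm, which is backed by fused CUDA custom ops and can fold the preceding residual addition into the normalization. Below is a minimal plain-PyTorch sketch of the interface this change assumes; the class name is hypothetical and the math is reference-only, not vLLM's kernel code.

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Hypothetical stand-in for vLLM's fused RMSNorm custom op."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Llama-style convention: weight starts at ones and is applied as
        # `out * weight`, unlike Gemma's `out * (1 + weight)`.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # Fused path: add the pending residual first, then also return
            # the summed tensor as the residual stream for the next sublayer.
            x = x + residual
            residual = x
        out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        out = out * self.weight
        return out if residual is None else (out, residual)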
@@ -185,10 +171,10 @@ class GemmaDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             linear_method=linear_method,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -196,25 +182,27 @@ class GemmaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states
+        return hidden_states, residual
 
 
 class GemmaModel(nn.Module):
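With the fused interface, the explicit `hidden_states = residual + hidden_states` adds disappear: each layernorm call performs the pending add itself and hands back the updated residual stream. A small equivalence check under that assumption, using plain-PyTorch reference math rather than the actual kernels:

import torch


def rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm reference: normalize over the last dim, scale by weight.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w


torch.manual_seed(0)
w = torch.rand(16)
residual = torch.randn(2, 16)   # residual stream entering the sublayer
attn_out = torch.randn(2, 16)   # sublayer (e.g. self-attention) output

# Old flow: add the residual explicitly, then normalize.
hidden_old = rms_norm(residual + attn_out, w)

# New flow: the fused norm does the add internally and returns the summed
# tensor as the next residual, so no standalone add remains in forward().
summed = attn_out + residual
hidden_new, next_residual = rms_norm(summed, w), summed

assert torch.allclose(hidden_old, hidden_new)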
@@ -235,7 +223,7 @@ class GemmaModel(nn.Module):
             GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -246,17 +234,19 @@ class GemmaModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         # Normalize the embedding by sqrt(hidden_size)
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        hidden_states *= self.config.hidden_size**0.5
 
+        residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
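One subtlety in the loop: residual starts as None, so the first layer's input_layernorm takes the plain unfused path, and the final self.norm call consumes the last pending residual itself, which is why its second return value is discarded. A compact sketch of that threading; the helper name is hypothetical and the loop body stands in for full decoder layers:

import torch


def fused_add_rms_norm_ref(x, residual, w, eps=1e-6):
    # Hypothetical reference for the fused call: add the pending residual
    # (if any), normalize, and return (output, new_residual).
    if residual is not None:
        x = x + residual
    out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w
    return out, x


w = torch.ones(8)
hidden = torch.randn(1, 8)
residual = None
for _ in range(3):  # stand-ins for the decoder layers
    hidden, residual = fused_add_rms_norm_ref(hidden, residual, w)
final, _ = fused_add_rms_norm_ref(hidden, residual, w)  # final norm flushes it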
@@ -321,6 +311,10 @@ class GemmaForCausalLM(nn.Module):
             # Skip loading extra layer for lora models.
             if "lm_head" in name:
                 continue
+            # GemmaRMSNorm is different from Llama's in that it multiplies
+            # (1 + weight) to the output, instead of just weight.
+            if "norm.weight" in name:
+                loaded_weight += 1.0
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
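Folding the +1 into the stored weight at load time is what lets the Llama-style op reproduce Gemma's (1 + weight) scaling exactly. A quick check of the identity:

import torch

out = torch.randn(4, 8)
w = torch.randn(8)  # norm weight as stored in the Gemma checkpoint

gemma_style = out * (1 + w)    # what GemmaRMSNorm computed
llama_style = out * (w + 1.0)  # Llama-style `out * weight` after
                               # `loaded_weight += 1.0` at load time
assert torch.equal(gemma_style, llama_style)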
@@ -329,5 +323,5 @@ class GemmaForCausalLM(nn.Module):
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
             raise RuntimeError(
-                f"Some weights are not initialized from checkpoints: {unloaded_params}"
-            )
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")