# picotron/src/nn/layer_norm.py

import torch
from torch import nn

from flash_attn.ops.triton.layer_norm import layer_norm_fn


class TritonRMSNorm(nn.Module):
    """RMSNorm backed by the fused Triton kernel from flash-attn."""

    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        # factory_kwargs lets the weight be allocated directly on the target device/dtype.
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))
        # RMSNorm has no bias; register it as None to keep the state_dict layout explicit.
        self.register_parameter("bias", None)

    def forward(
        self,
        hidden_states,
        residual=None,
        dropout_p=0.0,
        prenorm=False,
        residual_in_fp32=False,
        return_dropout_mask=False,
    ):
        # layer_norm_fn fuses the optional residual add, dropout, and normalization
        # into a single Triton kernel; is_rms_norm=True selects the RMSNorm variant
        # (no mean subtraction, no bias).
        return layer_norm_fn(
            hidden_states,
            self.weight,
            None,
            residual=residual,
            eps=self.eps,
            dropout_p=dropout_p,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=True,
            return_dropout_mask=return_dropout_mask,
        )


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Compute the norm in fp32 for numerical stability, then cast back.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
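

if __name__ == "__main__":
    # Illustrative smoke test (a sketch, not part of the original module): it checks
    # that LlamaRMSNorm matches a hand-rolled RMSNorm and, when CUDA is available,
    # that the flash-attn Triton kernel agrees with the pure-PyTorch implementation.
    torch.manual_seed(0)
    hidden_size = 16
    x = torch.randn(2, 4, hidden_size)

    ref = LlamaRMSNorm(hidden_size)
    out = ref(x)
    # Manual RMSNorm: x / sqrt(mean(x^2) + eps), scaled by the learned weight.
    expected = ref.weight * (x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + ref.variance_epsilon))
    torch.testing.assert_close(out, expected)

    if torch.cuda.is_available():
        triton_norm = TritonRMSNorm(hidden_size, device="cuda")
        out_triton = triton_norm(x.cuda())
        torch.testing.assert_close(out_triton.cpu(), out, atol=1e-5, rtol=1e-5)

    print("RMSNorm sanity checks passed.")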