import torch
from torch import nn

from flash_attn.ops.triton.layer_norm import layer_norm_fn


class TritonRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)

    def forward(
        self,
        hidden_states,
        residual=None,
        dropout_p=0.0,
        prenorm=False,
        residual_in_fp32=False,
        return_dropout_mask=False,
    ):
        # Fused Triton kernel: bias=None and is_rms_norm=True select the RMSNorm path.
        return layer_norm_fn(
            hidden_states,
            self.weight,
            None,
            residual=residual,
            eps=self.eps,
            dropout_p=dropout_p,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=True,
            return_dropout_mask=return_dropout_mask,
        )


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Compute the mean of squares in fp32 for numerical stability, then rescale.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
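

# Minimal usage sketch (not part of the original snippet): checks that the eager
# LlamaRMSNorm and the fused TritonRMSNorm agree on random input. It assumes
# flash-attn with its Triton kernels is installed and a CUDA device is available;
# the hidden size, batch shape, and dtype below are illustrative choices only.
if __name__ == "__main__":
    hidden_size = 4096
    x = torch.randn(2, 128, hidden_size, device="cuda", dtype=torch.bfloat16)

    eager_norm = LlamaRMSNorm(hidden_size).to(device="cuda", dtype=torch.bfloat16)
    fused_norm = TritonRMSNorm(hidden_size, device="cuda", dtype=torch.bfloat16)

    with torch.no_grad():
        y_eager = eager_norm(x)
        y_fused = fused_norm(x)

    # The two paths should match up to bf16 rounding error.
    print("max abs difference:", (y_eager - y_fused).abs().max().item())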