import torch
from torch import nn

from flash_attn.ops.triton.layer_norm import layer_norm_fn


class TritonRMSNorm(nn.Module):
    """RMSNorm backed by the fused Triton kernel from flash-attn."""

    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        # Learnable scale, created on the requested device/dtype.
        self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))
        # RMSNorm has no bias; register it as None to keep the module interface consistent.
        self.register_parameter("bias", None)

    def forward(
        self,
        hidden_states,
        residual=None,
        dropout_p=0.0,
        prenorm=False,
        residual_in_fp32=False,
        return_dropout_mask=False,
    ):
        # Delegate to the fused kernel; is_rms_norm=True selects the RMSNorm path.
        return layer_norm_fn(
            hidden_states,
            self.weight,
            None,
            residual=residual,
            eps=self.eps,
            dropout_p=dropout_p,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=True,
            return_dropout_mask=return_dropout_mask,
        )


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Compute the statistics in float32 for numerical stability.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back to the input dtype before applying the learnable scale.
        return self.weight * hidden_states.to(input_dtype)
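

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal comparison, assuming flash-attn is installed and a CUDA GPU is
# available, showing that the Triton-fused RMSNorm and the pure-PyTorch
# reference produce close outputs. Shapes and dtypes here are arbitrary choices.
if __name__ == "__main__":
    if torch.cuda.is_available():
        hidden_size = 4096
        x = torch.randn(2, 128, hidden_size, device="cuda", dtype=torch.float16)
        triton_norm = TritonRMSNorm(hidden_size, device="cuda", dtype=torch.float16)
        ref_norm = LlamaRMSNorm(hidden_size).to(device="cuda", dtype=torch.float16)
        # Both modules initialize their scale to ones, so outputs should match
        # up to kernel-level numerical differences.
        out_triton = triton_norm(x)
        out_ref = ref_norm(x)
        print("max abs diff:", (out_triton - out_ref).abs().max().item())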