# Copyright (c) 2024, Tri Dao.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None


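def _dropout_add_layer_norm_ref(x, residual, norm, dropout_p, training, rowscale=None):
    """Unfused reference sketch of the dropout -> (row-scale) -> add -> norm pattern used below.

    Illustrative only: this helper is *not* part of flash_attn's API and is not used by the
    classes in this file. It mirrors the non-fused branch of ``Block.forward`` and roughly what
    the Triton ``layer_norm_fn(..., prenorm=True)`` kernel computes when it is available.
    ``rowscale`` (if given) has shape ``x.shape[:-1]``, like the rowscale tensors built below.
    """
    dropped = F.dropout(x, p=dropout_p, training=training)
    if rowscale is not None:
        # Per-row scaling (stochastic depth), broadcast over the hidden dimension.
        dropped = dropped * rowscale.unsqueeze(-1)
    residual = dropped if residual is None else residual + dropped
    return norm(residual.to(dtype=norm.weight.dtype)), residual

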
class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reasons: for a post-norm architecture, returning the input allows
        us to fuse the backward of nn.Linear with the residual connection.

        A short usage sketch showing how the (hidden_states, residual) pair is threaded through a
        stack of prenorm blocks is given after this class definition.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None; if prenorm, hidden_states = Attn/MLP(LN(residual)).
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
            return hidden_states
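

# A minimal usage sketch (hypothetical helper, not part of flash_attn's API) of how prenorm
# Blocks are typically chained, as referenced in the Block docstring above: each Block consumes
# the (hidden_states, residual) pair returned by the previous one, and the very first Block
# receives residual=None.
def _stack_prenorm_blocks_ref(blocks, hidden_states: Tensor):
    residual = None
    for block in blocks:
        hidden_states, residual = block(hidden_states, residual)
    return hidden_states, residual

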
class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (output1 of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual: the residual from the previous block (None for the very first block).
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                (hidden_states2,) = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
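

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the library's test suite): run a sequential
    # prenorm Block and a ParallelBlock on random inputs. It assumes the default MHA / Mlp
    # sub-modules can run in their plain-PyTorch (non-fused) code path on the current device;
    # the sizes below are arbitrary.
    torch.manual_seed(0)
    dim, batch, seqlen = 128, 2, 16
    x = torch.randn(batch, seqlen, dim)

    block = Block(dim, prenorm=True)
    hidden_states, residual = block(x)  # first block: residual starts as None
    hidden_states, residual = block(hidden_states, residual)
    print("Block:", hidden_states.shape, residual.shape)

    pblock = ParallelBlock(dim, tied_norm=True)
    h1, h2, residual = pblock(x)  # first block: hidden_states2 and residual start as None
    h1, h2, residual = pblock(h1, h2, residual)
    print("ParallelBlock:", h1.shape, h2.shape, residual.shape)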