[MLP] Add ParallelMLP
This commit is contained in:
parent b3177dfaf6
commit 75e334d407
@@ -18,7 +18,7 @@ from einops import rearrange
 from flash_attn.ops.activations import sqrelu_fwd
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import Mlp, GatedMlp, FusedMLP, ParallelFusedMLP
+from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
@@ -112,10 +112,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                                         'supports approximate activation_function sqrelu')
     assert not (fused_dense_sqrelu_dense and fused_mlp)
-    if process_group is not None:
-        assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
     if not fused_mlp and not fused_dense_sqrelu_dense:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
+        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
                                               'sqrelu', 'glu', 'swiglu', 'geglu']
         if config.activation_function in ['glu', 'swiglu', 'geglu']:
             activation = (F.sigmoid if config.activation_function == 'glu'
@@ -132,8 +130,13 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
                 approximate = ('tanh' if config.activation_function
                                in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
                 activation=partial(F.gelu, approximate=approximate)
-            mlp_cls = partial(Mlp, hidden_features=config.n_inner, activation=activation,
-                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
+            mlp_cls = Mlp if process_group is None else ParallelMLP
+            parallel_kwargs = ({'process_group': process_group,
+                                'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                               if process_group is not None else {})
+            mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
+                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
+                              **parallel_kwargs, **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
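For orientation, here is a minimal sketch of what the non-fused branch now resolves to. The `tp_group` handle, the concrete sizes, and the final `mlp_cls(hidden_size)` call are illustrative assumptions, not part of this commit:

# Illustrative sketch only: selecting between Mlp and the new ParallelMLP.
# tp_group would be a torch.distributed ProcessGroup for tensor parallelism
# (None here, i.e. the single-GPU case); sizes stand in for config values.
from functools import partial
import torch.nn.functional as F
from flash_attn.modules.mlp import Mlp, ParallelMLP

tp_group = None                                   # or a group created with dist.new_group(...)
hidden_size, n_inner = 1024, 4096                 # assumed config.hidden_size / config.n_inner
activation = partial(F.gelu, approximate='none')  # config.activation_function == 'gelu', newly allowed here

mlp_cls = Mlp if tp_group is None else ParallelMLP
parallel_kwargs = ({'process_group': tp_group, 'sequence_parallel': True}
                   if tp_group is not None else {})
mlp_cls = partial(mlp_cls, hidden_features=n_inner, activation=activation, **parallel_kwargs)

mlp = mlp_cls(hidden_size)  # the Block later constructs the MLP from this partial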
@@ -288,6 +288,8 @@ class ParallelBlock(nn.Module):
             hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
             residual.
         """
+        # TODO: Ideally we should only do the allgather / allreduce once for
+        # the Linear to MLP & Attention
         fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
                              if isinstance(self.norm1, RMSNorm)
                              else dropout_add_layer_norm_parallel_residual)
@@ -3,6 +3,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed import ProcessGroup
 
+try:
+    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
+except ImportError:
+    ColumnParallelLinear, RowParallelLinear = None, None
+
 try:
     from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
@@ -30,6 +36,30 @@ class Mlp(nn.Module):
         return y if not self.return_residual else (y, x)
 
 
+class ParallelMLP(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
+                 process_group: ProcessGroup = None, sequence_parallel=True,
+                 bias1=True, bias2=True, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        assert ColumnParallelLinear is not None, "Need to install fused_dense"
+        assert RowParallelLinear is not None, "Need to install fused_dense"
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
+        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
+                                        sequence_parallel=sequence_parallel, **factory_kwargs)
+        self.activation = activation
+        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
+                                     sequence_parallel=sequence_parallel, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        y = self.activation(y)
+        y = self.fc2(y)
+        return y
+
+
 class GatedMlp(nn.Module):
 
     def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
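A hedged usage sketch of the new module follows. It assumes a torchrun launch with one GPU per rank, the NCCL backend, and that the fused_dense extension (which provides ColumnParallelLinear / RowParallelLinear) is installed; none of this setup comes from the commit itself.

# Usage sketch for ParallelMLP under tensor parallelism; sizes and the choice of
# process group (all ranks) are illustrative assumptions.
import os
import torch
import torch.distributed as dist
from flash_attn.modules.mlp import ParallelMLP

dist.init_process_group(backend='nccl')                     # launched via torchrun
torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))
tp_group = dist.new_group(range(dist.get_world_size()))     # tensor-parallel group over all ranks

mlp = ParallelMLP(1024, hidden_features=4096, process_group=tp_group,
                  sequence_parallel=False,                  # each rank sees the full sequence
                  device='cuda', dtype=torch.float16)

x = torch.randn(2, 512, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)  # fc1 shards the 4096 hidden features across ranks; fc2 reduces the
            # partial outputs, so y should come back with x's shape: (2, 512, 1024)

With sequence_parallel=True (the default), the expectation would instead be that each rank holds a shard along the sequence dimension, with ColumnParallelLinear gathering it before fc1 and RowParallelLinear reduce-scattering after fc2.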