[MLP] Add ParallelMLP

Tri Dao 2023-07-22 23:45:51 -07:00
parent b3177dfaf6
commit 75e334d407
3 changed files with 41 additions and 6 deletions
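For context, a minimal usage sketch of the new ParallelMLP (this is not code from the commit). It assumes the script is launched with torchrun, that torch.distributed is initialized with NCCL, and that flash_attn.ops.fused_dense is installed, since ParallelMLP asserts that ColumnParallelLinear / RowParallelLinear are importable:

```python
# Sketch only: exercises ParallelMLP on the default (world) process group.
# Assumes a torchrun launch and an installed flash_attn.ops.fused_dense.
import torch
import torch.distributed as dist

from flash_attn.modules.mlp import ParallelMLP

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

mlp = ParallelMLP(
    in_features=1024,
    hidden_features=4096,            # split across tensor-parallel ranks by ColumnParallelLinear
    process_group=dist.group.WORLD,  # in practice this would be a tensor-parallel sub-group
    sequence_parallel=False,         # keep the full sequence on every rank for this sketch
    device="cuda",
    dtype=torch.float16,
)

x = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16)
y = mlp(x)  # fc1 is column-parallel; fc2 is row-parallel and all-reduces its output
assert y.shape == x.shape
```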

flash_attn/models/gpt.py

@@ -18,7 +18,7 @@ from einops import rearrange
 from flash_attn.ops.activations import sqrelu_fwd
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import Mlp, GatedMlp, FusedMLP, ParallelFusedMLP
+from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
@@ -112,10 +112,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                                         'supports approximate activation_function sqrelu')
     assert not (fused_dense_sqrelu_dense and fused_mlp)
-    if process_group is not None:
-        assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
     if not fused_mlp and not fused_dense_sqrelu_dense:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
+        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
                                               'sqrelu', 'glu', 'swiglu', 'geglu']
         if config.activation_function in ['glu', 'swiglu', 'geglu']:
             activation = (F.sigmoid if config.activation_function == 'glu'
@@ -132,8 +130,13 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
                 approximate = ('tanh' if config.activation_function
                                in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
                 activation = partial(F.gelu, approximate=approximate)
-            mlp_cls = partial(Mlp, hidden_features=config.n_inner, activation=activation,
-                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
+            mlp_cls = Mlp if process_group is None else ParallelMLP
+            parallel_kwargs = ({'process_group': process_group,
+                                'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                               if process_group is not None else {})
+            mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
+                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
+                              **parallel_kwargs, **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
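The practical effect of the change above: tensor parallelism no longer requires FusedMLP, and plain 'gelu' is now an accepted activation. A hedged sketch (the config values are illustrative, not from the repo) of calling create_mlp_cls with a process group, which now returns a ParallelMLP partial instead of tripping the removed "Tensor Parallel is only implemented for FusedMLP" assert:

```python
# Illustrative sketch, not code from the repo: assumes an initialized default
# process group and that flash_attn.ops.fused_dense is available.
import torch.distributed as dist
from transformers import GPT2Config
from flash_attn.models.gpt import create_mlp_cls

config = GPT2Config(n_embd=1024, n_inner=4096, activation_function="gelu")

# Before this commit, passing process_group without config.fused_mlp=True hit
# the removed assert; now the non-fused path builds a ParallelMLP instead.
mlp_cls = create_mlp_cls(config, layer_idx=0, process_group=dist.group.WORLD)
mlp = mlp_cls(config.n_embd)  # ParallelMLP(in_features=1024, hidden_features=4096, ...)
```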

flash_attn/modules/block.py

@@ -288,6 +288,8 @@ class ParallelBlock(nn.Module):
             hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
             residual.
         """
+        # TODO: Ideally we should only do the allgather / allreduce once for
+        # the Linear to MLP & Attention
         fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
                              if isinstance(self.norm1, RMSNorm)
                              else dropout_add_layer_norm_parallel_residual)

flash_attn/modules/mlp.py

@@ -3,6 +3,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed import ProcessGroup
 
+try:
+    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
+except ImportError:
+    ColumnParallelLinear, RowParallelLinear = None, None
+
 try:
     from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
@@ -30,6 +36,30 @@ class Mlp(nn.Module):
         return y if not self.return_residual else (y, x)
 
 
+class ParallelMLP(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
+                 process_group: ProcessGroup = None, sequence_parallel=True,
+                 bias1=True, bias2=True, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        assert ColumnParallelLinear is not None, "Need to install fused_dense"
+        assert RowParallelLinear is not None, "Need to install fused_dense"
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
+        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
+                                        sequence_parallel=sequence_parallel, **factory_kwargs)
+        self.activation = activation
+        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
+                                     sequence_parallel=sequence_parallel, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        y = self.activation(y)
+        y = self.fc2(y)
+        return y
+
+
 class GatedMlp(nn.Module):
     def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,