[MLP] Add ParallelMLP
commit 75e334d407
parent b3177dfaf6
@@ -18,7 +18,7 @@ from einops import rearrange
 from flash_attn.ops.activations import sqrelu_fwd
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import Mlp, GatedMlp, FusedMLP, ParallelFusedMLP
+from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
@@ -112,10 +112,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                                         'supports approximate activation_function sqrelu')
     assert not (fused_dense_sqrelu_dense and fused_mlp)
-    if process_group is not None:
-        assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
     if not fused_mlp and not fused_dense_sqrelu_dense:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
+        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
                                               'sqrelu', 'glu', 'swiglu', 'geglu']
         if config.activation_function in ['glu', 'swiglu', 'geglu']:
             activation = (F.sigmoid if config.activation_function == 'glu'
@@ -132,8 +130,13 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
                 approximate = ('tanh' if config.activation_function
                                in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
                 activation=partial(F.gelu, approximate=approximate)
-            mlp_cls = partial(Mlp, hidden_features=config.n_inner, activation=activation,
-                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
+            mlp_cls = Mlp if process_group is None else ParallelMLP
+            parallel_kwargs = ({'process_group': process_group,
+                                'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                               if process_group is not None else {})
+            mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
+                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
+                              **parallel_kwargs, **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
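Note on the hunk above: create_mlp_cls previously asserted that tensor parallelism was only available with FusedMLP; after this commit the plain-PyTorch branch dispatches to ParallelMLP whenever a process_group is supplied. A minimal sketch of the resulting selection logic, assuming a GPT2Config-style config with an n_inner attribute (pick_mlp_cls is a hypothetical wrapper written for illustration, not a function in the repo; bias handling and most of the activation selection are trimmed for brevity):

from functools import partial

import torch.nn.functional as F

from flash_attn.modules.mlp import Mlp, ParallelMLP


def pick_mlp_cls(config, process_group=None, device=None, dtype=None):
    # Sketch of the non-fused branch of create_mlp_cls after this commit.
    factory_kwargs = {'device': device, 'dtype': dtype}
    # Activation selection reduced to the tanh-approximated GELU for brevity.
    activation = partial(F.gelu, approximate='tanh')
    mlp_cls = Mlp if process_group is None else ParallelMLP
    parallel_kwargs = ({'process_group': process_group,
                        'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                       if process_group is not None else {})
    return partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
                   **parallel_kwargs, **factory_kwargs)

Calling the returned partial with in_features then yields an Mlp on single-GPU runs and a ParallelMLP under tensor parallelism.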
@@ -288,6 +288,8 @@ class ParallelBlock(nn.Module):
             hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
             residual.
         """
+        # TODO: Ideally we should only do the allgather / allreduce once for
+        # the Linear to MLP & Attention
         fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
                              if isinstance(self.norm1, RMSNorm)
                              else dropout_add_layer_norm_parallel_residual)
@@ -3,6 +3,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed import ProcessGroup
+
+try:
+    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
+except ImportError:
+    ColumnParallelLinear, RowParallelLinear = None, None
 
 try:
     from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
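The new try/except mirrors the existing FusedMLP guard: flash_attn.modules.mlp still imports on builds without the fused_dense CUDA extension, and the missing dependency only surfaces when ParallelMLP is actually constructed, via the asserts added in the next hunk. A small sketch of the intended behavior, assuming the extension is absent (pg stands in for an initialized process group):

# On a build without flash_attn.ops.fused_dense:
from flash_attn.modules.mlp import Mlp, ParallelMLP   # import still succeeds

mlp = Mlp(1024)                        # works, no tensor-parallel dependency
# ParallelMLP(1024, process_group=pg)  # raises AssertionError: "Need to install fused_dense"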
@@ -30,6 +36,30 @@ class Mlp(nn.Module):
         return y if not self.return_residual else (y, x)
 
 
+class ParallelMLP(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
+                 process_group: ProcessGroup = None, sequence_parallel=True,
+                 bias1=True, bias2=True, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        assert ColumnParallelLinear is not None, "Need to install fused_dense"
+        assert RowParallelLinear is not None, "Need to install fused_dense"
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
+        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
+                                        sequence_parallel=sequence_parallel, **factory_kwargs)
+        self.activation = activation
+        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
+                                     sequence_parallel=sequence_parallel, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        y = self.activation(y)
+        y = self.fc2(y)
+        return y
+
+
 class GatedMlp(nn.Module):
 
     def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
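For context, a usage sketch of the new ParallelMLP, not taken from the repo: it assumes a single-node torchrun launch with an NCCL process group and the fused_dense extension installed. fc1 (ColumnParallelLinear) shards the hidden dimension across ranks and fc2 (RowParallelLinear) reduces the partial results, so each rank ends up with the full output.

import os

import torch
import torch.distributed as dist

from flash_attn.modules.mlp import ParallelMLP

# Launched with e.g. `torchrun --nproc_per_node=2 parallel_mlp_demo.py`
dist.init_process_group(backend='nccl')
torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))
process_group = dist.group.WORLD

# hidden_features defaults to 4 * in_features (4096 here) and should be divisible
# by the world size, since fc1 shards it across ranks.
mlp = ParallelMLP(1024, process_group=process_group, sequence_parallel=False,
                  device='cuda', dtype=torch.float16)

x = torch.randn(2, 512, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)  # same shape as x, replicated on every rank

With sequence_parallel=True (the default, and what create_mlp_cls forwards when config.sequence_parallel is set), the input and output are instead expected to be sharded along the sequence dimension, with roughly an all-gather before fc1 and a reduce-scatter after fc2 in place of the all-reduce.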