Implement GatedMlp

Tri Dao 2023-04-18 03:37:14 -07:00
parent ac3b684cdb
commit b630aef53f
3 changed files with 57 additions and 25 deletions

flash_attn/models/gpt.py

@@ -16,8 +16,9 @@ from transformers import GPT2Config
 from einops import rearrange
+from flash_attn.ops.activations import sqrelu_fwd
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
+from flash_attn.modules.mlp import Mlp, GatedMlp, FusedMLP, ParallelFusedMLP
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
@@ -43,10 +44,9 @@ except ImportError:
     dropout_add_layer_norm_parallel_residual = None
 
 try:
-    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense, sqrelu_fwd
+    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
 except ImportError:
     FusedDenseSqreluDense = None
-    sqrelu_fwd = None
 
 logger = logging.getLogger(__name__)
@@ -90,7 +90,6 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
 
 def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {'device': device, 'dtype': dtype}
-    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     fused_mlp = getattr(config, 'fused_mlp', False)
     if fused_mlp:
         assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
@@ -102,17 +101,25 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     if process_group is not None:
         assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
     if not fused_mlp and not fused_dense_sqrelu_dense:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
-        if config.activation_function == 'relu':
-            activation = partial(F.relu, inplace=True)
-        elif config.activation_function == 'sqrelu':
-            assert sqrelu_fwd is not None, 'sqrelu_fwd is not implemented'
-            activation = sqrelu_fwd
+        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
+                                              'sqrelu', 'glu', 'swiglu', 'geglu']
+        if config.activation_function in ['glu', 'swiglu', 'geglu']:
+            activation = (F.sigmoid if config.activation_function == 'glu'
+                          else (F.silu if config.activation_function == 'swiglu'
+                                else F.gelu))
+            mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
+                              **factory_kwargs)
         else:
-            approximate = ('tanh' if config.activation_function
-                           in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
-            activation=partial(F.gelu, approximate=approximate)
-        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
+            if config.activation_function == 'relu':
+                activation = partial(F.relu, inplace=True)
+            elif config.activation_function == 'sqrelu':
+                activation = sqrelu_fwd
+            else:
+                approximate = ('tanh' if config.activation_function
+                               in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
+                activation=partial(F.gelu, approximate=approximate)
+            mlp_cls = partial(Mlp, hidden_features=config.n_inner, activation=activation,
+                              **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
@@ -128,12 +135,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
         parallel_kwargs = ({'process_group': process_group,
                             'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                            if process_group is not None else {})
-        mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
+        mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
                           checkpoint_lvl=mlp_checkpoint_lvl,
                           **parallel_kwargs, **factory_kwargs)
     elif fused_dense_sqrelu_dense:
         assert FusedDenseSqreluDense is not None
-        mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
+        mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
                           checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
     else:
         raise RuntimeError('MLP type not supported')
@@ -252,7 +259,7 @@ class GPTModel(GPTPreTrainedModel):
         self.process_group = process_group
         self.sequence_parallel = getattr(config, 'sequence_parallel', True)
         assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
-                                              'relu', 'sqrelu']
+                                              'relu', 'sqrelu', 'glu', 'swiglu', 'geglu']
         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                       * pad_vocab_size_multiple)
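
The new branch above maps 'glu', 'swiglu' and 'geglu' to GatedMlp with sigmoid, SiLU and GELU as the gate activation. A minimal usage sketch, not part of the commit, assuming this hunk lives in flash_attn/models/gpt.py and that GPT2Config simply stores the non-standard activation name:

import torch
from transformers import GPT2Config
from flash_attn.models.gpt import create_mlp_cls

# 'swiglu' selects F.silu as the gate; n_inner=None lets GatedMlp pick its own width.
config = GPT2Config(n_embd=1024, n_inner=None, activation_function='swiglu')
mlp_cls = create_mlp_cls(config)               # partial(GatedMlp, hidden_features=None, activation=F.silu, ...)
mlp = mlp_cls(in_features=config.hidden_size)
y = mlp(torch.randn(2, 16, config.hidden_size))  # (batch, seqlen, hidden) in, same shape out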

flash_attn/modules/mlp.py

@@ -28,3 +28,28 @@ class Mlp(nn.Module):
         y = self.activation(y)
         y = self.fc2(y)
         return y if not self.return_residual else (y, x)
+
+
+class GatedMlp(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
+                 multiple_of=128, return_residual=False, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or int(8 * in_features / 3)
+        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        self.return_residual = return_residual
+        self.fc1 = nn.Linear(in_features, 2 * hidden_features, **factory_kwargs)
+        self.activation = activation
+        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        if self.activation == F.sigmoid:  # Special case for GLU
+            y = F.glu(y, dim=-1)
+        else:
+            y, gate = y.chunk(2, dim=-1)
+            y = y * self.activation(gate)
+        y = self.fc2(y)
+        return y if not self.return_residual else (y, x)
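
A short usage sketch for the GatedMlp added above; shapes are illustrative and not taken from the diff. Passing activation=F.silu gives the SwiGLU variant, while the default F.sigmoid hits the F.glu fast path:

import torch
import torch.nn.functional as F
from flash_attn.modules.mlp import GatedMlp

mlp = GatedMlp(in_features=1024, activation=F.silu)  # hidden_features defaults to int(8 * 1024 / 3), rounded up to 2816
x = torch.randn(2, 16, 1024)
y = mlp(x)        # fc1 projects to 2 * 2816, SiLU-gated elementwise product, fc2 back to 1024
print(y.shape)    # torch.Size([2, 16, 1024])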

flash_attn/ops/fused_dense.py

@@ -404,7 +404,7 @@ def fused_mlp_func(
 class FusedMLP(nn.Module):
 
-    def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
+    def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True,
                  bias2=True, activation='gelu_approx', return_residual=False,
                  checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
         """
@@ -432,8 +432,8 @@ class FusedMLP(nn.Module):
         assert activation in ['gelu_approx', 'relu', 'sqrelu']
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        if out_features is None:
-            out_features = in_features
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
         self.activation = activation
         self.return_residual = return_residual
         self.checkpoint_lvl = checkpoint_lvl
@@ -469,9 +469,9 @@ class FusedMLP(nn.Module):
 class ParallelFusedMLP(nn.Module):
 
-    def __init__(self, in_features, hidden_features, out_features=None, activation='gelu_approx',
-                 process_group: ProcessGroup = None, bias1=True, bias2=True,
-                 sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
+    def __init__(self, in_features, hidden_features=None, out_features=None,
+                 activation='gelu_approx', process_group: ProcessGroup = None,
+                 bias1=True, bias2=True, sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
                  device=None, dtype=None):
         """
         process_group is required. We're doing Tensor Parallel with sequence parallelism:
@@ -494,8 +494,8 @@ class ParallelFusedMLP(nn.Module):
         assert process_group is not None
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        if out_features is None:
-            out_features = in_features
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
         self.activation = activation
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
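
With the two hunks above, hidden_features and out_features become optional for FusedMLP and ParallelFusedMLP, defaulting to 4 * in_features and in_features. A hedged sketch of the new default, assuming this file is flash_attn/ops/fused_dense.py, the fused_dense CUDA extension is built, and a GPU is available (ParallelFusedMLP additionally needs an initialized process group):

import torch
from flash_attn.ops.fused_dense import FusedMLP

mlp = FusedMLP(in_features=1024, device='cuda', dtype=torch.float16)  # hidden_features -> 4096, out_features -> 1024
x = torch.randn(2, 16, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)  # same shape as x, computed with the fused 'gelu_approx' MLP kernel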