Implement GatedMlp
This commit is contained in:
parent
ac3b684cdb
commit
b630aef53f
@ -16,8 +16,9 @@ from transformers import GPT2Config
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.ops.activations import sqrelu_fwd
|
||||
from flash_attn.modules.mha import MHA, ParallelMHA
|
||||
from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
|
||||
from flash_attn.modules.mlp import Mlp, GatedMlp, FusedMLP, ParallelFusedMLP
|
||||
from flash_attn.modules.block import Block, ParallelBlock
|
||||
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
|
||||
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
|
||||
@ -43,10 +44,9 @@ except ImportError:
|
||||
dropout_add_layer_norm_parallel_residual = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense, sqrelu_fwd
|
||||
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
|
||||
except ImportError:
|
||||
FusedDenseSqreluDense = None
|
||||
sqrelu_fwd = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -90,7 +90,6 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
|
||||
|
||||
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
||||
fused_mlp = getattr(config, 'fused_mlp', False)
|
||||
if fused_mlp:
|
||||
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
|
||||
@ -102,17 +101,25 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
|
||||
if process_group is not None:
|
||||
assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
|
||||
if not fused_mlp and not fused_dense_sqrelu_dense:
|
||||
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
|
||||
if config.activation_function == 'relu':
|
||||
activation = partial(F.relu, inplace=True)
|
||||
elif config.activation_function == 'sqrelu':
|
||||
assert sqrelu_fwd is not None, 'sqrelu_fwd is not implemented'
|
||||
activation = sqrelu_fwd
|
||||
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
|
||||
'sqrelu', 'glu', 'swiglu', 'geglu']
|
||||
if config.activation_function in ['glu', 'swiglu', 'geglu']:
|
||||
activation = (F.sigmoid if config.activation_function == 'glu'
|
||||
else (F.silu if config.activation_function == 'swiglu'
|
||||
else F.gelu))
|
||||
mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
|
||||
**factory_kwargs)
|
||||
else:
|
||||
approximate = ('tanh' if config.activation_function
|
||||
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
|
||||
activation=partial(F.gelu, approximate=approximate)
|
||||
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
|
||||
if config.activation_function == 'relu':
|
||||
activation = partial(F.relu, inplace=True)
|
||||
elif config.activation_function == 'sqrelu':
|
||||
activation = sqrelu_fwd
|
||||
else:
|
||||
approximate = ('tanh' if config.activation_function
|
||||
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
|
||||
activation=partial(F.gelu, approximate=approximate)
|
||||
mlp_cls = partial(Mlp, hidden_features=config.n_inner, activation=activation,
|
||||
**factory_kwargs)
|
||||
else:
|
||||
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
|
||||
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
||||
@ -128,12 +135,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
|
||||
parallel_kwargs = ({'process_group': process_group,
|
||||
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
|
||||
if process_group is not None else {})
|
||||
mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
|
||||
mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
|
||||
checkpoint_lvl=mlp_checkpoint_lvl,
|
||||
**parallel_kwargs, **factory_kwargs)
|
||||
elif fused_dense_sqrelu_dense:
|
||||
assert FusedDenseSqreluDense is not None
|
||||
mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
|
||||
mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
|
||||
checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
|
||||
else:
|
||||
raise RuntimeError('MLP type not supported')
|
||||
@ -252,7 +259,7 @@ class GPTModel(GPTPreTrainedModel):
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
|
||||
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
|
||||
'relu', 'sqrelu']
|
||||
'relu', 'sqrelu', 'glu', 'swiglu', 'geglu']
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple)
|
||||
|
||||
@ -28,3 +28,28 @@ class Mlp(nn.Module):
|
||||
y = self.activation(y)
|
||||
y = self.fc2(y)
|
||||
return y if not self.return_residual else (y, x)
|
||||
|
||||
|
||||
class GatedMlp(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
|
||||
multiple_of=128, return_residual=False, device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or int(8 * in_features / 3)
|
||||
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
||||
self.return_residual = return_residual
|
||||
self.fc1 = nn.Linear(in_features, 2 * hidden_features, **factory_kwargs)
|
||||
self.activation = activation
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.fc1(x)
|
||||
if self.activation == F.sigmoid: # Special case for GLU
|
||||
y = F.glu(y, dim=-1)
|
||||
else:
|
||||
y, gate = y.chunk(2, dim=-1)
|
||||
y = y * self.activation(gate)
|
||||
y = self.fc2(y)
|
||||
return y if not self.return_residual else (y, x)
|
||||
|
||||
@ -404,7 +404,7 @@ def fused_mlp_func(
|
||||
|
||||
class FusedMLP(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True,
|
||||
bias2=True, activation='gelu_approx', return_residual=False,
|
||||
checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
|
||||
"""
|
||||
@ -432,8 +432,8 @@ class FusedMLP(nn.Module):
|
||||
assert activation in ['gelu_approx', 'relu', 'sqrelu']
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = in_features
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features * 4
|
||||
self.activation = activation
|
||||
self.return_residual = return_residual
|
||||
self.checkpoint_lvl = checkpoint_lvl
|
||||
@ -469,9 +469,9 @@ class FusedMLP(nn.Module):
|
||||
|
||||
class ParallelFusedMLP(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features, out_features=None, activation='gelu_approx',
|
||||
process_group: ProcessGroup = None, bias1=True, bias2=True,
|
||||
sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None,
|
||||
activation='gelu_approx', process_group: ProcessGroup = None,
|
||||
bias1=True, bias2=True, sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
|
||||
device=None, dtype=None):
|
||||
"""
|
||||
process_group is required. We're doing Tensor Parallel with sequence parallelism:
|
||||
@ -494,8 +494,8 @@ class ParallelFusedMLP(nn.Module):
|
||||
assert process_group is not None
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = in_features
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features * 4
|
||||
self.activation = activation
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = sequence_parallel
|
||||
|
||||
Loading…
Reference in New Issue
Block a user