Implement GatedMlp
This commit is contained in:
parent
ac3b684cdb
commit
b630aef53f
@ -16,8 +16,9 @@ from transformers import GPT2Config
|
|||||||
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
|
from flash_attn.ops.activations import sqrelu_fwd
|
||||||
from flash_attn.modules.mha import MHA, ParallelMHA
|
from flash_attn.modules.mha import MHA, ParallelMHA
|
||||||
from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
|
from flash_attn.modules.mlp import Mlp, GatedMlp, FusedMLP, ParallelFusedMLP
|
||||||
from flash_attn.modules.block import Block, ParallelBlock
|
from flash_attn.modules.block import Block, ParallelBlock
|
||||||
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
|
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
|
||||||
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
|
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
|
||||||
@ -43,10 +44,9 @@ except ImportError:
|
|||||||
dropout_add_layer_norm_parallel_residual = None
|
dropout_add_layer_norm_parallel_residual = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense, sqrelu_fwd
|
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
|
||||||
except ImportError:
|
except ImportError:
|
||||||
FusedDenseSqreluDense = None
|
FusedDenseSqreluDense = None
|
||||||
sqrelu_fwd = None
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -90,7 +90,6 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
|
|||||||
|
|
||||||
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
|
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
|
||||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
|
||||||
fused_mlp = getattr(config, 'fused_mlp', False)
|
fused_mlp = getattr(config, 'fused_mlp', False)
|
||||||
if fused_mlp:
|
if fused_mlp:
|
||||||
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
|
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
|
||||||
@ -102,17 +101,25 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
|
|||||||
if process_group is not None:
|
if process_group is not None:
|
||||||
assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
|
assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
|
||||||
if not fused_mlp and not fused_dense_sqrelu_dense:
|
if not fused_mlp and not fused_dense_sqrelu_dense:
|
||||||
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
|
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
|
||||||
if config.activation_function == 'relu':
|
'sqrelu', 'glu', 'swiglu', 'geglu']
|
||||||
activation = partial(F.relu, inplace=True)
|
if config.activation_function in ['glu', 'swiglu', 'geglu']:
|
||||||
elif config.activation_function == 'sqrelu':
|
activation = (F.sigmoid if config.activation_function == 'glu'
|
||||||
assert sqrelu_fwd is not None, 'sqrelu_fwd is not implemented'
|
else (F.silu if config.activation_function == 'swiglu'
|
||||||
activation = sqrelu_fwd
|
else F.gelu))
|
||||||
|
mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
|
||||||
|
**factory_kwargs)
|
||||||
else:
|
else:
|
||||||
approximate = ('tanh' if config.activation_function
|
if config.activation_function == 'relu':
|
||||||
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
|
activation = partial(F.relu, inplace=True)
|
||||||
activation=partial(F.gelu, approximate=approximate)
|
elif config.activation_function == 'sqrelu':
|
||||||
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
|
activation = sqrelu_fwd
|
||||||
|
else:
|
||||||
|
approximate = ('tanh' if config.activation_function
|
||||||
|
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
|
||||||
|
activation=partial(F.gelu, approximate=approximate)
|
||||||
|
mlp_cls = partial(Mlp, hidden_features=config.n_inner, activation=activation,
|
||||||
|
**factory_kwargs)
|
||||||
else:
|
else:
|
||||||
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
|
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
|
||||||
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
||||||
@ -128,12 +135,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
|
|||||||
parallel_kwargs = ({'process_group': process_group,
|
parallel_kwargs = ({'process_group': process_group,
|
||||||
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
|
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
|
||||||
if process_group is not None else {})
|
if process_group is not None else {})
|
||||||
mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
|
mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
|
||||||
checkpoint_lvl=mlp_checkpoint_lvl,
|
checkpoint_lvl=mlp_checkpoint_lvl,
|
||||||
**parallel_kwargs, **factory_kwargs)
|
**parallel_kwargs, **factory_kwargs)
|
||||||
elif fused_dense_sqrelu_dense:
|
elif fused_dense_sqrelu_dense:
|
||||||
assert FusedDenseSqreluDense is not None
|
assert FusedDenseSqreluDense is not None
|
||||||
mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
|
mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
|
||||||
checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
|
checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('MLP type not supported')
|
raise RuntimeError('MLP type not supported')
|
||||||
@ -252,7 +259,7 @@ class GPTModel(GPTPreTrainedModel):
|
|||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
|
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
|
||||||
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
|
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
|
||||||
'relu', 'sqrelu']
|
'relu', 'sqrelu', 'glu', 'swiglu', 'geglu']
|
||||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
|
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
|
||||||
* pad_vocab_size_multiple)
|
* pad_vocab_size_multiple)
|
||||||
|
|||||||
@ -28,3 +28,28 @@ class Mlp(nn.Module):
|
|||||||
y = self.activation(y)
|
y = self.activation(y)
|
||||||
y = self.fc2(y)
|
y = self.fc2(y)
|
||||||
return y if not self.return_residual else (y, x)
|
return y if not self.return_residual else (y, x)
|
||||||
|
|
||||||
|
|
||||||
|
class GatedMlp(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
|
||||||
|
multiple_of=128, return_residual=False, device=None, dtype=None):
|
||||||
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or int(8 * in_features / 3)
|
||||||
|
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
||||||
|
self.return_residual = return_residual
|
||||||
|
self.fc1 = nn.Linear(in_features, 2 * hidden_features, **factory_kwargs)
|
||||||
|
self.activation = activation
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.fc1(x)
|
||||||
|
if self.activation == F.sigmoid: # Special case for GLU
|
||||||
|
y = F.glu(y, dim=-1)
|
||||||
|
else:
|
||||||
|
y, gate = y.chunk(2, dim=-1)
|
||||||
|
y = y * self.activation(gate)
|
||||||
|
y = self.fc2(y)
|
||||||
|
return y if not self.return_residual else (y, x)
|
||||||
|
|||||||
@ -404,7 +404,7 @@ def fused_mlp_func(
|
|||||||
|
|
||||||
class FusedMLP(nn.Module):
|
class FusedMLP(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
|
def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True,
|
||||||
bias2=True, activation='gelu_approx', return_residual=False,
|
bias2=True, activation='gelu_approx', return_residual=False,
|
||||||
checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
|
checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
|
||||||
"""
|
"""
|
||||||
@ -432,8 +432,8 @@ class FusedMLP(nn.Module):
|
|||||||
assert activation in ['gelu_approx', 'relu', 'sqrelu']
|
assert activation in ['gelu_approx', 'relu', 'sqrelu']
|
||||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if out_features is None:
|
out_features = out_features or in_features
|
||||||
out_features = in_features
|
hidden_features = hidden_features or in_features * 4
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
self.return_residual = return_residual
|
self.return_residual = return_residual
|
||||||
self.checkpoint_lvl = checkpoint_lvl
|
self.checkpoint_lvl = checkpoint_lvl
|
||||||
@ -469,9 +469,9 @@ class FusedMLP(nn.Module):
|
|||||||
|
|
||||||
class ParallelFusedMLP(nn.Module):
|
class ParallelFusedMLP(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_features, hidden_features, out_features=None, activation='gelu_approx',
|
def __init__(self, in_features, hidden_features=None, out_features=None,
|
||||||
process_group: ProcessGroup = None, bias1=True, bias2=True,
|
activation='gelu_approx', process_group: ProcessGroup = None,
|
||||||
sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
|
bias1=True, bias2=True, sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
|
||||||
device=None, dtype=None):
|
device=None, dtype=None):
|
||||||
"""
|
"""
|
||||||
process_group is required. We're doing Tensor Parallel with sequence parallelism:
|
process_group is required. We're doing Tensor Parallel with sequence parallelism:
|
||||||
@ -494,8 +494,8 @@ class ParallelFusedMLP(nn.Module):
|
|||||||
assert process_group is not None
|
assert process_group is not None
|
||||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if out_features is None:
|
out_features = out_features or in_features
|
||||||
out_features = in_features
|
hidden_features = hidden_features or in_features * 4
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
self.sequence_parallel = sequence_parallel
|
self.sequence_parallel = sequence_parallel
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user