diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index bbdd2ca..6292347 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -16,8 +16,9 @@ from transformers import GPT2Config
 
 from einops import rearrange
 
+from flash_attn.ops.activations import sqrelu_fwd
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
+from flash_attn.modules.mlp import Mlp, GatedMlp, FusedMLP, ParallelFusedMLP
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
@@ -43,10 +44,9 @@ except ImportError:
     dropout_add_layer_norm_parallel_residual = None
 
 try:
-    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense, sqrelu_fwd
+    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
 except ImportError:
     FusedDenseSqreluDense = None
-    sqrelu_fwd = None
 
 
 logger = logging.getLogger(__name__)
@@ -90,7 +90,6 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
 
 def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {'device': device, 'dtype': dtype}
-    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     fused_mlp = getattr(config, 'fused_mlp', False)
     if fused_mlp:
         assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
@@ -102,17 +101,25 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
     if process_group is not None:
         assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
     if not fused_mlp and not fused_dense_sqrelu_dense:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
-        if config.activation_function == 'relu':
-            activation = partial(F.relu, inplace=True)
-        elif config.activation_function == 'sqrelu':
-            assert sqrelu_fwd is not None, 'sqrelu_fwd is not implemented'
-            activation = sqrelu_fwd
+        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
+                                              'sqrelu', 'glu', 'swiglu', 'geglu']
+        if config.activation_function in ['glu', 'swiglu', 'geglu']:
+            activation = (F.sigmoid if config.activation_function == 'glu'
+                          else (F.silu if config.activation_function == 'swiglu'
+                                else F.gelu))
+            mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
+                              **factory_kwargs)
         else:
-            approximate = ('tanh' if config.activation_function
-                           in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
-            activation=partial(F.gelu, approximate=approximate)
-        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
+            if config.activation_function == 'relu':
+                activation = partial(F.relu, inplace=True)
+            elif config.activation_function == 'sqrelu':
+                activation = sqrelu_fwd
+            else:
+                approximate = ('tanh' if config.activation_function
+                               in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
+                activation=partial(F.gelu, approximate=approximate)
+            mlp_cls = partial(Mlp, hidden_features=config.n_inner, activation=activation,
+                              **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
@@ -128,12 +135,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
             parallel_kwargs = ({'process_group': process_group,
                                 'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                                if process_group is not None else {})
-            mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
+            mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
                               checkpoint_lvl=mlp_checkpoint_lvl, **parallel_kwargs,
                               **factory_kwargs)
         elif fused_dense_sqrelu_dense:
             assert FusedDenseSqreluDense is not None
-            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
+            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
                               checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
         else:
             raise RuntimeError('MLP type not supported')
@@ -252,7 +259,7 @@ class GPTModel(GPTPreTrainedModel):
         self.process_group = process_group
         self.sequence_parallel = getattr(config, 'sequence_parallel', True)
         assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
-                                              'relu', 'sqrelu']
+                                              'relu', 'sqrelu', 'glu', 'swiglu', 'geglu']
         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                       * pad_vocab_size_multiple)
diff --git a/flash_attn/modules/mlp.py b/flash_attn/modules/mlp.py
index 902bd3b..9a8da60 100644
--- a/flash_attn/modules/mlp.py
+++ b/flash_attn/modules/mlp.py
@@ -28,3 +28,28 @@ class Mlp(nn.Module):
         y = self.activation(y)
         y = self.fc2(y)
         return y if not self.return_residual else (y, x)
+
+
+class GatedMlp(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
+                 multiple_of=128, return_residual=False, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or int(8 * in_features / 3)
+        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        self.return_residual = return_residual
+        self.fc1 = nn.Linear(in_features, 2 * hidden_features, **factory_kwargs)
+        self.activation = activation
+        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        if self.activation == F.sigmoid:  # Special case for GLU
+            y = F.glu(y, dim=-1)
+        else:
+            y, gate = y.chunk(2, dim=-1)
+            y = y * self.activation(gate)
+        y = self.fc2(y)
+        return y if not self.return_residual else (y, x)
diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py
index 1931145..d69d683 100644
--- a/flash_attn/ops/fused_dense.py
+++ b/flash_attn/ops/fused_dense.py
@@ -404,7 +404,7 @@ def fused_mlp_func(
 
 class FusedMLP(nn.Module):
 
-    def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
+    def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True,
                  bias2=True, activation='gelu_approx', return_residual=False, checkpoint_lvl=0,
                  heuristic='auto', device=None, dtype=None):
         """
@@ -432,8 +432,8 @@ class FusedMLP(nn.Module):
         assert activation in ['gelu_approx', 'relu', 'sqrelu']
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        if out_features is None:
-            out_features = in_features
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
         self.activation = activation
         self.return_residual = return_residual
         self.checkpoint_lvl = checkpoint_lvl
@@ -469,9 +469,9 @@ class FusedMLP(nn.Module):
 
 class ParallelFusedMLP(nn.Module):
 
-    def __init__(self, in_features, hidden_features, out_features=None, activation='gelu_approx',
-                 process_group: ProcessGroup = None, bias1=True, bias2=True,
-                 sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
+    def __init__(self, in_features, hidden_features=None, out_features=None,
+                 activation='gelu_approx', process_group: ProcessGroup = None,
+                 bias1=True, bias2=True, sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
                  device=None, dtype=None):
         """
         process_group is required. We're doing Tensor Parallel with sequence parallelism:
@@ -494,8 +494,8 @@ class ParallelFusedMLP(nn.Module):
         assert process_group is not None
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        if out_features is None:
-            out_features = in_features
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
         self.activation = activation
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
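
Minimal usage sketch (not part of the patch) of the GatedMlp module added above; the tensor shapes and variable names below are illustrative assumptions. With the patch applied, GatedMlp is importable from flash_attn.modules.mlp, and the GPT path selects it by setting config.activation_function to 'glu', 'swiglu', or 'geglu'; when config.n_inner is None, hidden_features defaults to 8/3 * in_features rounded up to a multiple of 128.

import torch
import torch.nn.functional as F

from flash_attn.modules.mlp import GatedMlp

x = torch.randn(2, 16, 768)  # (batch, seqlen, hidden_size); sizes are arbitrary examples

# SwiGLU-style MLP: fc1 projects to 2 * hidden_features; the second half, passed
# through F.silu, gates the first half before fc2 projects back to in_features.
swiglu = GatedMlp(in_features=768, activation=F.silu)
print(swiglu.fc1.out_features)  # 4096 = 2 * round_up(8 * 768 / 3, multiple_of=128)
print(swiglu(x).shape)          # torch.Size([2, 16, 768])

# Default activation is F.sigmoid, i.e. plain GLU; forward takes the F.glu fast path.
glu = GatedMlp(in_features=768)
print(glu(x).shape)             # torch.Size([2, 16, 768])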