diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 88df046..0b6472b 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -18,7 +18,7 @@
 from einops import rearrange
 from flash_attn.ops.activations import sqrelu_fwd
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import Mlp, GatedMlp, FusedMLP, ParallelFusedMLP
+from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
@@ -112,10 +112,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                                         'supports approximate activation_function sqrelu')
     assert not (fused_dense_sqrelu_dense and fused_mlp)
-    if process_group is not None:
-        assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
     if not fused_mlp and not fused_dense_sqrelu_dense:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
+        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
                                               'sqrelu', 'glu', 'swiglu', 'geglu']
         if config.activation_function in ['glu', 'swiglu', 'geglu']:
             activation = (F.sigmoid if config.activation_function == 'glu'
@@ -132,8 +130,13 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
             approximate = ('tanh' if config.activation_function
                            in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
             activation=partial(F.gelu, approximate=approximate)
-        mlp_cls = partial(Mlp, hidden_features=config.n_inner, activation=activation,
-                          bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
+        mlp_cls = Mlp if process_group is None else ParallelMLP
+        parallel_kwargs = ({'process_group': process_group,
+                            'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                           if process_group is not None else {})
+        mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
+                          bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
+                          **parallel_kwargs, **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py
index e19742d..af856b6 100644
--- a/flash_attn/modules/block.py
+++ b/flash_attn/modules/block.py
@@ -288,6 +288,8 @@ class ParallelBlock(nn.Module):
             hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
             residual.
         """
+        # TODO: Ideally we should only do the allgather / allreduce once for
+        # the Linear to MLP & Attention
         fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
                              if isinstance(self.norm1, RMSNorm)
                              else dropout_add_layer_norm_parallel_residual)
diff --git a/flash_attn/modules/mlp.py b/flash_attn/modules/mlp.py
index 64661fd..f0c9acc 100644
--- a/flash_attn/modules/mlp.py
+++ b/flash_attn/modules/mlp.py
@@ -3,6 +3,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed import ProcessGroup
+
+try:
+    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
+except ImportError:
+    ColumnParallelLinear, RowParallelLinear = None, None
 
 try:
     from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
@@ -30,6 +36,30 @@ class Mlp(nn.Module):
         return y if not self.return_residual else (y, x)
 
 
+class ParallelMLP(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
+                 process_group: ProcessGroup = None, sequence_parallel=True,
+                 bias1=True, bias2=True, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        assert ColumnParallelLinear is not None, "Need to install fused_dense"
+        assert RowParallelLinear is not None, "Need to install fused_dense"
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
+        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
+                                        sequence_parallel=sequence_parallel, **factory_kwargs)
+        self.activation = activation
+        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
+                                     sequence_parallel=sequence_parallel, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        y = self.activation(y)
+        y = self.fc2(y)
+        return y
+
+
 class GatedMlp(nn.Module):
 
     def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,