diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index 1dc3d0c..745b5d5 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -18,7 +18,8 @@ from einops import rearrange from flash_attn.ops.activations import sqrelu_fwd from flash_attn.modules.mha import MHA, ParallelMHA -from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP +from flash_attn.modules.mlp import Mlp, ParallelMLP, FusedMLP, ParallelFusedMLP +from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp from flash_attn.modules.block import Block, ParallelBlock from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings from flash_attn.utils.distributed import sync_shared_params, all_gather_raw @@ -122,8 +123,13 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp activation = (F.sigmoid if config.activation_function == 'glu' else (F.silu if config.activation_function == 'swiglu' else F.gelu)) - mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation, - bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs) + mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp + parallel_kwargs = ({'process_group': process_group, + 'sequence_parallel': getattr(config, 'sequence_parallel', True)} + if process_group is not None else {}) + mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation, + bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, + **parallel_kwargs, **factory_kwargs) else: if config.activation_function == 'relu': activation = partial(F.relu, inplace=True) @@ -160,6 +166,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **parallel_kwargs, **factory_kwargs) elif fused_dense_sqrelu_dense: + if process_group is not None: + assert fused_mlp, 'Tensor Parallel is not implemented for FusedDenseSqreluDense' assert FusedDenseSqreluDense is not None mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner, checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs) diff --git a/flash_attn/modules/mlp.py b/flash_attn/modules/mlp.py index f0c9acc..0e74d19 100644 --- a/flash_attn/modules/mlp.py +++ b/flash_attn/modules/mlp.py @@ -11,9 +11,10 @@ except ImportError: ColumnParallelLinear, RowParallelLinear = None, None try: - from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP + from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP, ColumnParallelLinear, RowParallelLinear except ImportError: FusedMLP, ParallelFusedMLP = None, None + ColumnParallelLinear, RowParallelLinear = None, None class Mlp(nn.Module): @@ -73,7 +74,7 @@ class GatedMlp(nn.Module): self.return_residual = return_residual self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs) self.activation = activation - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias1, **factory_kwargs) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) def forward(self, x): y = self.fc1(x) @@ -84,3 +85,27 @@ class GatedMlp(nn.Module): y = y * self.activation(gate) y = self.fc2(y) return y if not self.return_residual else (y, x) + + +class ParallelGatedMlp(GatedMlp): + """ Parallel GatedMlp """ + + def __init__(self, in_features, process_group, hidden_features=None, out_features=None, activation=F.sigmoid, + bias1=True, bias2=True, multiple_of=256, return_residual=False, + sequence_parallel=True, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__(in_features, hidden_features=hidden_features, out_features=out_features, activation=activation, + bias1=bias1, bias2=bias2, multiple_of=multiple_of, return_residual=return_residual, + device=device, dtype=dtype) + out_features = out_features or in_features + hidden_features = hidden_features or int(8 * in_features / 3) + hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of + + if ColumnParallelLinear is None or RowParallelLinear is None: + raise ImportError('fused_dense is not installed') + self.fc1 = ColumnParallelLinear(in_features, 2 * hidden_features, process_group, + bias=bias1, + sequence_parallel=sequence_parallel, **factory_kwargs) + self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, + bias=bias2, + sequence_parallel=sequence_parallel, **factory_kwargs) diff --git a/tests/modules/test_mlp_parallel.py b/tests/modules/test_mlp_parallel.py new file mode 100644 index 0000000..1601e71 --- /dev/null +++ b/tests/modules/test_mlp_parallel.py @@ -0,0 +1,114 @@ +# Run test with: +# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py + +import torch +import torch.nn.functional as F +import pytest + +from einops import rearrange + +from apex.transformer import parallel_state +from apex.transformer import tensor_parallel + +from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp + +is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 + + +@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else [])) +# @pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize('world_size', [1, 2, 4, 8]) +# @pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('sequence_parallel', [True, False]) +# @pytest.mark.parametrize('sequence_parallel', [False]) +@pytest.mark.parametrize('activation', [F.silu, F.sigmoid]) +# @pytest.mark.parametrize('activation', [F.silu]) +@pytest.mark.parametrize('dim', [1024, 4096]) +# @pytest.mark.parametrize('dim', [1024]) +def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype): + rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend='nccl', init_method='env://') + device = f'cuda:{torch.distributed.get_rank()}' + assert world_size <= torch.distributed.get_world_size() + parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) + rank = parallel_state.get_tensor_model_parallel_rank() + # set seed + torch.random.manual_seed(0) + batch_size = 2 + seqlen = 1024 + assert (batch_size * seqlen) % world_size == 0 + x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, + requires_grad=True) + # We need to generate g here so that all processes get the same gradient, + # as rank 0 will have an extra bias that changes the RNG. + # If we don't divide by batch_size, the gradient gets a bit too large. + g = torch.randn_like(x_pt) / 32 + if sequence_parallel: + x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() + else: + x = x_pt.detach().clone().requires_grad_() + + model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype) + partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size + model = ParallelGatedMlp(dim, parallel_state.get_tensor_model_parallel_group(), + activation=activation, + sequence_parallel=sequence_parallel, device=device, dtype=dtype) + + with torch.no_grad(): + model.fc1.weight.copy_( + rearrange(rearrange(model_pt.fc1.weight, '(two o) i -> two o i', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim], + 'two o i -> (two o) i') + ) + model.fc1.bias.copy_( + rearrange(rearrange(model_pt.fc1.bias, '(two o) -> two o', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim], + 'two o -> (two o)') + ) + model.fc2.weight.copy_( + model_pt.fc2.weight[:, rank * partition_dim:(rank + 1) * partition_dim] + ) + if rank == 0: + model.fc2.bias.copy_(model_pt.fc2.bias) + + out = model(x) + out_pt = model_pt(x_pt) + partition_batch_dim = batch_size * seqlen // world_size + assert torch.allclose( + out, + out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else out_pt, + rtol=rtol, atol=atol + ) + + out_pt.backward(g) + out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else g) + parallel_state.destroy_model_parallel() + + assert torch.allclose( + x.grad, + x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else x_pt.grad, + rtol=rtol, atol=atol + ) + + assert torch.allclose( + model.fc1.weight.grad, + rearrange(rearrange(model_pt.fc1.weight.grad, '(two o) i -> two o i', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim], + 'two o i -> (two o) i'), + rtol=rtol, atol=atol + ) + assert torch.allclose( + model.fc1.bias.grad, + rearrange(rearrange(model_pt.fc1.bias.grad, '(two o) -> two o', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim], + 'two o -> (two o)'), + rtol=rtol, atol=atol + ) + assert torch.allclose( + model.fc2.weight.grad, + model_pt.fc2.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim], + rtol=rtol, atol=atol + ) + if rank == 0: + assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)