Implement ParallelGatedMlp (#251)

This commit is contained in:
Haodong Lyu 2023-07-27 03:14:15 +08:00 committed by GitHub
parent 56ccaff126
commit 8ee62efca3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 152 additions and 5 deletions

View File

@ -18,7 +18,8 @@ from einops import rearrange
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP
from flash_attn.modules.mlp import Mlp, ParallelMLP, FusedMLP, ParallelFusedMLP
from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
@ -122,8 +123,13 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
activation = (F.sigmoid if config.activation_function == 'glu'
else (F.silu if config.activation_function == 'swiglu'
else F.gelu))
mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {})
mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
**parallel_kwargs, **factory_kwargs)
else:
if config.activation_function == 'relu':
activation = partial(F.relu, inplace=True)
@ -160,6 +166,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
**parallel_kwargs, **factory_kwargs)
elif fused_dense_sqrelu_dense:
if process_group is not None:
assert fused_mlp, 'Tensor Parallel is not implemented for FusedDenseSqreluDense'
assert FusedDenseSqreluDense is not None
mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)

View File

@ -11,9 +11,10 @@ except ImportError:
ColumnParallelLinear, RowParallelLinear = None, None
try:
from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP, ColumnParallelLinear, RowParallelLinear
except ImportError:
FusedMLP, ParallelFusedMLP = None, None
ColumnParallelLinear, RowParallelLinear = None, None
class Mlp(nn.Module):
@ -73,7 +74,7 @@ class GatedMlp(nn.Module):
self.return_residual = return_residual
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
self.activation = activation
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias1, **factory_kwargs)
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
def forward(self, x):
y = self.fc1(x)
@ -84,3 +85,27 @@ class GatedMlp(nn.Module):
y = y * self.activation(gate)
y = self.fc2(y)
return y if not self.return_residual else (y, x)
class ParallelGatedMlp(GatedMlp):
""" Parallel GatedMlp """
def __init__(self, in_features, process_group, hidden_features=None, out_features=None, activation=F.sigmoid,
bias1=True, bias2=True, multiple_of=256, return_residual=False,
sequence_parallel=True, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(in_features, hidden_features=hidden_features, out_features=out_features, activation=activation,
bias1=bias1, bias2=bias2, multiple_of=multiple_of, return_residual=return_residual,
device=device, dtype=dtype)
out_features = out_features or in_features
hidden_features = hidden_features or int(8 * in_features / 3)
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
if ColumnParallelLinear is None or RowParallelLinear is None:
raise ImportError('fused_dense is not installed')
self.fc1 = ColumnParallelLinear(in_features, 2 * hidden_features, process_group,
bias=bias1,
sequence_parallel=sequence_parallel, **factory_kwargs)
self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
bias=bias2,
sequence_parallel=sequence_parallel, **factory_kwargs)

View File

@ -0,0 +1,114 @@
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('activation', [F.silu, F.sigmoid])
# @pytest.mark.parametrize('activation', [F.silu])
@pytest.mark.parametrize('dim', [1024, 4096])
# @pytest.mark.parametrize('dim', [1024])
def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype,
requires_grad=True)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
else:
x = x_pt.detach().clone().requires_grad_()
model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype)
partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size
model = ParallelGatedMlp(dim, parallel_state.get_tensor_model_parallel_group(),
activation=activation,
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
with torch.no_grad():
model.fc1.weight.copy_(
rearrange(rearrange(model_pt.fc1.weight, '(two o) i -> two o i', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o i -> (two o) i')
)
model.fc1.bias.copy_(
rearrange(rearrange(model_pt.fc1.bias, '(two o) -> two o', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o -> (two o)')
)
model.fc2.weight.copy_(
model_pt.fc2.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
)
if rank == 0:
model.fc2.bias.copy_(model_pt.fc2.bias)
out = model(x)
out_pt = model_pt(x_pt)
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
)
out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol
)
assert torch.allclose(
model.fc1.weight.grad,
rearrange(rearrange(model_pt.fc1.weight.grad, '(two o) i -> two o i', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o i -> (two o) i'),
rtol=rtol, atol=atol
)
assert torch.allclose(
model.fc1.bias.grad,
rearrange(rearrange(model_pt.fc1.bias.grad, '(two o) -> two o', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o -> (two o)'),
rtol=rtol, atol=atol
)
assert torch.allclose(
model.fc2.weight.grad,
model_pt.fc2.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
rtol=rtol, atol=atol
)
if rank == 0:
assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)