[TP] Implement TensorParallel without sequence parallel

Tri Dao 2023-01-07 13:45:22 -08:00
parent ce26d3d73d
commit 93383bd55b
11 changed files with 257 additions and 133 deletions

View File

@ -20,7 +20,7 @@ from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_sequence_parallel_params
from flash_attn.utils.distributed import sync_shared_params
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin
@ -62,7 +62,9 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
mha_cls = MHA if process_group is None else ParallelMHA
serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
if process_group is None else {})
parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {})
mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
@ -99,7 +101,9 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
if FusedDenseGeluDense is None:
raise ImportError('fused_dense is not installed')
mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense
parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {})
mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl,
**parallel_kwargs, **factory_kwargs)
elif fused_dense_sqrelu_dense:
@ -113,13 +117,15 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
sequence_parallel = getattr(config, 'sequence_parallel', True)
mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=True, resid_dropout=config.resid_pdrop,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
sequence_parallel=process_group is not None)
sequence_parallel=sequence_parallel and process_group is not None,
mark_shared_params=process_group is not None)
block.layer_idx = layer_idx
return block
@ -180,6 +186,7 @@ class GPTModel(GPTPreTrainedModel):
super().__init__(config)
factory_kwargs = {'device': device, 'dtype': dtype}
self.process_group = process_group
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
if config.vocab_size % self.pad_vocab_size_multiple != 0:
@ -192,7 +199,8 @@ class GPTModel(GPTPreTrainedModel):
else:
self.embeddings = ParallelGPT2Embeddings(
config.hidden_size, config.vocab_size, config.max_position_embeddings,
process_group=process_group, **factory_kwargs
process_group=process_group, sequence_parallel=self.sequence_parallel,
**factory_kwargs
)
self.emb_drop = nn.Dropout(config.embd_pdrop)
@ -209,10 +217,13 @@ class GPTModel(GPTPreTrainedModel):
# is the final layer norm.
self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
**factory_kwargs)
# Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
if process_group is not None:
for p in self.ln_0.parameters():
p._sequence_parallel = True
# Mark the norm parameters as "shared_params" so that we sync their values at init.
p._shared_params = True
# Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
if self.sequence_parallel:
p._sequence_parallel = True
self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
**factory_kwargs)
@ -224,14 +235,14 @@ class GPTModel(GPTPreTrainedModel):
def tie_weights(self):
if self.process_group is not None:
sync_sequence_parallel_params(self, self.process_group)
sync_shared_params(self, self.process_group)
def forward(self, input_ids, position_ids=None, inference_params=None):
# If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
# dimensions so that we can split on it easily, in case of small batch size.
# Only the attention layers need to know the seqlen.
embedding_kwargs = ({'combine_batch_seqlen_dim': True}
if self.process_group is not None else {})
if self.process_group is not None and self.sequence_parallel else {})
hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
if not self.fused_dropout_add_ln:
@ -243,7 +254,8 @@ class GPTModel(GPTPreTrainedModel):
self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
residual_in_fp32=True
)
mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {})
mixer_kwargs = ({'seqlen': input_ids.shape[1]}
if self.process_group is not None and self.sequence_parallel else {})
if inference_params is not None:
mixer_kwargs['inference_params'] = inference_params
for layer in self.layers:
@ -263,8 +275,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
else:
if ColumnParallelLinear is None:
raise ImportError('fused_dense_lib is not installed')
self.lm_head = ColumnParallelLinear(config.n_embd, config.vocab_size, process_group,
bias=False, **factory_kwargs)
self.lm_head = ColumnParallelLinear(
config.n_embd, config.vocab_size, process_group, bias=False,
sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
)
# Initialize weights and apply final processing
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
@ -273,7 +287,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
def tie_weights(self):
self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
if self.process_group is not None:
sync_sequence_parallel_params(self, self.process_group)
sync_shared_params(self, self.process_group)
def forward(self, input_ids, position_ids=None, inference_params=None):
"""

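A note on the shape bookkeeping above: `combine_batch_seqlen_dim` flattens the batch and sequence dimensions so the activations can be sharded along a single axis even when the batch is small, and the `seqlen` mixer kwarg lets the attention layer undo that flattening once it has gathered the full sequence back. A minimal single-process sketch of the idea, with hypothetical sizes and the sequence-parallel scatter/gather emulated by `torch.chunk`/`torch.cat` rather than real collectives:

```python
import torch
from einops import rearrange

# Toy sizes; world_size stands in for the tensor-parallel group size.
batch, seqlen, dim, world_size = 4, 16, 8, 2

hidden = torch.randn(batch, seqlen, dim)

# combine_batch_seqlen_dim=True: flatten (b, s) so the tensor can be split
# along one dimension even when batch < world_size.
flat = rearrange(hidden, 'b s d -> (b s) d')                 # (b*s, d)

# Between layers each rank holds one shard of the flattened tokens
# (emulated here with torch.chunk instead of a real reduce_scatter).
shard = torch.chunk(flat, world_size, dim=0)[0]              # (b*s/world, d)

# Inside the attention layer the column-parallel QKV projection gathers the
# token dimension back (emulated with torch.cat); `seqlen` is what lets the
# layer restore the (b, s, d) view for the attention kernel.
gathered = torch.cat([shard] * world_size, dim=0)            # (b*s, d) again
restored = rearrange(gathered, '(b s) d -> b s d', s=seqlen)
assert restored.shape == (batch, seqlen, dim)
```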
View File

@ -23,7 +23,8 @@ class Block(nn.Module):
def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False):
fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False,
mark_shared_params=False):
"""
return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
This is for performance reason: for post-norm architecture, returning the input allows us
@ -51,6 +52,12 @@ class Block(nn.Module):
assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
# then the input to each worker in the tensor parallel group will be different.
# This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
# For now this is not an issue because we always use sequence_parallel=True during training
# and only use sequence_parallel=False during inference.
# Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
if sequence_parallel:
for p in self.norm1.parameters():
@ -58,6 +65,13 @@ class Block(nn.Module):
if hasattr(self, 'norm2'):
for p in self.norm2.parameters():
p._sequence_parallel = True
# Mark the norm parameters as "shared_params" so that we sync their values at init.
if mark_shared_params:
for p in self.norm1.parameters():
p._shared_params = True
if hasattr(self, 'norm2'):
for p in self.norm2.parameters():
p._shared_params = True
def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
mixer_kwargs=None):

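The `_sequence_parallel` and `_shared_params` markers set in `Block.__init__` are plain Python attributes on the `Parameter` objects; the helpers in `flash_attn.utils.distributed` later pick them up by attribute and by sorted name. A small standalone sketch of the tagging convention (not the library code, just the pattern):

```python
import torch.nn as nn

# Two norms standing in for Block.norm1 / Block.norm2.
model = nn.ModuleDict({'norm1': nn.LayerNorm(16), 'norm2': nn.LayerNorm(16)})

# Tagging, as Block.__init__ and GPTModel.__init__ do: plain attributes on the
# Parameter objects themselves.
for p in model.parameters():
    p._sequence_parallel = True   # grads get all-reduced across the TP group
    p._shared_params = True       # values get broadcast from the group's rank 0

# Helpers such as sync_shared_params / allreduce_sequence_parallel_grad later
# recover the tagged parameters by attribute, sorted by name so every rank
# iterates them in the same order.
shared = {n: p for n, p in model.named_parameters()
          if getattr(p, '_shared_params', False)}
print(sorted(shared))  # ['norm1.bias', 'norm1.weight', 'norm2.bias', 'norm2.weight']
```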
View File

@ -6,7 +6,7 @@ from torch import Tensor
from einops import rearrange
from flash_attn.utils.distributed import reduce_scatter
from flash_attn.utils.distributed import reduce_scatter, all_reduce
class GPT2Embeddings(nn.Module):
@ -130,13 +130,14 @@ class ColumnParallelEmbedding(nn.Embedding):
class ParallelGPT2Embeddings(nn.Module):
def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
padding_idx=None, device=None, dtype=None):
padding_idx=None, sequence_parallel=True, device=None, dtype=None):
"""
If max_position_embeddings <= 0, there's no position embeddings
"""
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.word_embeddings = VocabParallelEmbedding(
vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
**factory_kwargs
@ -167,4 +168,5 @@ class ParallelGPT2Embeddings(nn.Module):
embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
if combine_batch_seqlen_dim:
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
return embeddings if world_size <= 1 else reduce_scatter(embeddings, self.process_group)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)

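What the `reduce_fn` switch changes at the embedding output, sketched in a single process with the collectives emulated by plain tensor ops (sizes are made up): with `sequence_parallel=True` each rank keeps only the summed embeddings for its token shard, with `sequence_parallel=False` every rank keeps the full summed output, replicated.

```python
import torch

# Per-rank partial embedding outputs for 8 tokens (hypothetical sizes).
world_size, tokens, dim = 2, 8, 4
partials = [torch.randn(tokens, dim) for _ in range(world_size)]
full = sum(partials)                  # the sum the collective must produce

# sequence_parallel=True -> reduce_scatter: each rank ends with the summed
# embeddings for its own slice of the (batch*seqlen) token dimension.
rs_per_rank = list(torch.chunk(full, world_size, dim=0))
assert rs_per_rank[0].shape == (tokens // world_size, dim)

# sequence_parallel=False -> all_reduce: every rank ends with the full summed
# embeddings, replicated.
ar_per_rank = [full.clone() for _ in range(world_size)]
assert ar_per_rank[0].shape == (tokens, dim)
```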
View File

@ -497,11 +497,10 @@ class ParallelMHA(nn.Module):
def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
rotary_emb_scale_base=0,
use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
rotary_emb_scale_base=0, use_flash_attn=False, checkpointing=False,
sequence_parallel=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.process_group = process_group
self.embed_dim = embed_dim
self.causal = causal
self.layer_idx = layer_idx
@ -521,12 +520,13 @@ class ParallelMHA(nn.Module):
if ColumnParallelLinear is None or RowParallelLinear is None:
raise ImportError('fused_dense is not installed')
self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
**factory_kwargs)
sequence_parallel=sequence_parallel, **factory_kwargs)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
# output projection always have the bias (for now)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, **factory_kwargs)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
sequence_parallel=sequence_parallel, **factory_kwargs)
def forward(self, x, seqlen=None, **kwargs):
"""

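For reference, the head sharding that `ColumnParallelLinear`/`RowParallelLinear` imply inside `ParallelMHA`, sketched with plain `nn.Linear` shapes and hypothetical sizes (the real classes additionally carry the process group and the `sequence_parallel` flag):

```python
import torch.nn as nn

# Hypothetical sizes: 1024-dim model, 16 heads, tensor-parallel world_size 2.
embed_dim, num_heads, world_size = 1024, 16, 2
head_dim = embed_dim // num_heads

# ColumnParallelLinear shards out_features: each rank projects to the Q, K, V
# of its own num_heads // world_size heads.
wqkv_local = nn.Linear(embed_dim, 3 * embed_dim // world_size)
assert wqkv_local.out_features == 3 * (num_heads // world_size) * head_dim

# RowParallelLinear shards in_features: each rank consumes only its heads'
# attention output, and the partial products are combined by reduce_scatter
# (sequence_parallel=True) or all_reduce (sequence_parallel=False).
out_proj_local = nn.Linear(embed_dim // world_size, embed_dim)
assert out_proj_local.in_features == (num_heads // world_size) * head_dim
```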
View File

@ -15,26 +15,29 @@ from torch.cuda.amp import custom_bwd, custom_fwd
import fused_dense_lib as fused_dense_cuda
from flash_attn.ops.gelu_activation import gelu_bwd
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, reduce_scatter
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
from flash_attn.utils.distributed import reduce_scatter, all_reduce
class FusedDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None):
def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
sequence_parallel=True):
"""
If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather_raw of x before doing the matmul.
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
"""
ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual
ctx.process_group = process_group
ctx.sequence_parallel = sequence_parallel
if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
x = x.contiguous()
if process_group is not None:
if process_group is not None and sequence_parallel:
# We want to kick off the all_gather early, before weight dtype conversion
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
else:
@ -44,7 +47,7 @@ class FusedDenseFunc(torch.autograd.Function):
weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
weight = weight.contiguous()
if process_group is not None:
if process_group is not None and sequence_parallel:
handle_x.wait()
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
batch_dim = batch_shape.numel()
@ -66,9 +69,10 @@ class FusedDenseFunc(torch.autograd.Function):
grad_input, = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
if ctx.compute_weight_gradient:
x, weight = ctx.saved_tensors
if process_group is not None:
if process_group is not None and sequence_parallel:
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
else:
total_x = x
@ -86,13 +90,13 @@ class FusedDenseFunc(torch.autograd.Function):
grad_output, weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
async_op=True)
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
else:
grad_input = None
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
if process_group is not None:
if process_group is not None and sequence_parallel:
handle_x.wait()
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
@ -102,15 +106,17 @@ class FusedDenseFunc(torch.autograd.Function):
grad_bias = grad_output if ctx.needs_input_grad[2] else None
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return grad_input, grad_weight, grad_bias, None, None
return grad_input, grad_weight, grad_bias, None, None, None
def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
return_residual: bool = False, process_group: Optional[ProcessGroup] = None):
return_residual: bool = False, process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True):
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group,
sequence_parallel)
else:
assert process_group is None
out = F.linear(x, weight, bias)
@ -136,7 +142,7 @@ class FusedDense(nn.Linear):
class ColumnParallelLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
bias: bool = True, device=None, dtype=None) -> None:
bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
world_size = torch.distributed.get_world_size(process_group)
if out_features % world_size != 0:
raise ValueError(f'out_features ({out_features}) must be divisible by '
@ -144,19 +150,20 @@ class ColumnParallelLinear(nn.Linear):
super().__init__(in_features, out_features // world_size, bias=bias,
device=device, dtype=dtype)
self.process_group = process_group
self.sequence_parallel = sequence_parallel
def forward(self, x):
"""
We're doing Tensor Parallel with sequence parallelism: we do an all_gather of
x before doing the matmul.
"""
return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group,
sequence_parallel=self.sequence_parallel)
class RowParallelLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
bias: bool = True, device=None, dtype=None) -> None:
bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
world_size = torch.distributed.get_world_size(process_group)
rank = torch.distributed.get_rank(process_group)
if in_features % world_size != 0:
@ -166,6 +173,7 @@ class RowParallelLinear(nn.Linear):
super().__init__(in_features // world_size, out_features, bias=bias and rank == 0,
device=device, dtype=dtype)
self.process_group = process_group
self.sequence_parallel = sequence_parallel
def forward(self, x):
"""
@ -173,7 +181,8 @@ class RowParallelLinear(nn.Linear):
a reduce_scatter of the result.
"""
out = fused_dense_func(x, self.weight, self.bias)
return reduce_scatter(out, self.process_group)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group)
class FusedDenseGeluDenseFunc(torch.autograd.Function):
@ -181,10 +190,11 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
checkpoint_lvl=0, heuristic=0, process_group=None):
checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True):
"""
If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul.
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather of x before doing the matmul.
If sequence_parallel=False, then the input is already gathered.
checkpoint_lvl:
0: no recomputation in the bwd
@ -197,13 +207,14 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
assert checkpoint_lvl in [0, 1, 2]
ctx.return_residual = return_residual
ctx.process_group = process_group
ctx.sequence_parallel = sequence_parallel
ctx.checkpoint_lvl = checkpoint_lvl
ctx.heuristic = heuristic
if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
x = x.contiguous()
if process_group is not None:
if process_group is not None and sequence_parallel:
# We want to kick off the all_gather early, before weight dtype conversion
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
else:
@ -218,7 +229,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
bias1 = bias1.contiguous() if bias1 is not None else None
weight2 = weight2.contiguous()
bias2 = bias2.contiguous() if bias2 is not None else None
if process_group is not None:
if process_group is not None and sequence_parallel:
handle_x.wait()
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
batch_dim = batch_shape.numel()
@ -257,13 +268,14 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
grad_input, = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
x, weight1, weight2, *rest = ctx.saved_tensors
if process_group is None:
if process_group is None or not sequence_parallel:
total_x = x
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
if checkpoint_lvl in [0, 1]:
if process_group is not None:
if process_group is not None and sequence_parallel:
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
if checkpoint_lvl == 0:
gelu_in, output1 = rest
@ -272,7 +284,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
output1 = F.gelu(gelu_in, approximate='tanh')
elif checkpoint_lvl == 2:
bias1, = rest
if process_group is not None:
if process_group is not None and sequence_parallel:
total_x, _ = all_gather_raw(x, process_group)
if ctx.heuristic == -1:
gelu_in = F.linear(total_x, weight1, bias1)
@ -314,13 +326,13 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
grad_gelu, weight1)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
async_op=True)
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
else:
grad_input = None
if ctx.heuristic == -1:
if ctx.needs_input_grad[1]:
if process_group is not None:
if process_group is not None and sequence_parallel:
handle_x.wait()
grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu,
@ -331,7 +343,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
else:
if ctx.needs_input_grad[1]:
if process_group is not None:
if process_group is not None and sequence_parallel:
handle_x.wait()
grad_weight1 = F.linear(grad_gelu.t(),
total_x.reshape(batch_dim, total_x.shape[-1]).t())
@ -340,7 +352,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
None, None, None, None, None)
None, None, None, None, None, None)
def fused_dense_gelu_dense_func(
@ -348,15 +360,16 @@ def fused_dense_gelu_dense_func(
bias2: Optional[Tensor] = None,
save_pre_act: bool = True, return_residual: bool = False,
checkpoint_lvl: int = 0, heuristic: int = 0,
process_group: Optional[ProcessGroup] = None
process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True
):
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
and (bias2 is None or bias2.is_cuda) and dtype_eligible):
return FusedDenseGeluDenseFunc.apply(
x, weight1, bias1, weight2, bias2,
save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group
x, weight1, bias1, weight2, bias2, save_pre_act, return_residual,
checkpoint_lvl, heuristic, process_group, sequence_parallel
)
else:
assert process_group is None
@ -418,7 +431,7 @@ class ParallelFusedDenseGeluDense(nn.Module):
def __init__(self, in_features, hidden_features, out_features=None,
process_group: ProcessGroup = None, bias1=True, bias2=True,
checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
"""
process_group is required. We're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul, gelu, then matmul.
@ -441,6 +454,7 @@ class ParallelFusedDenseGeluDense(nn.Module):
if out_features is None:
out_features = in_features
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.checkpoint_lvl = checkpoint_lvl
self.heuristic = heuristic
self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
@ -452,6 +466,9 @@ class ParallelFusedDenseGeluDense(nn.Module):
out = fused_dense_gelu_dense_func(
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
heuristic=self.heuristic, process_group=self.process_group
heuristic=self.heuristic,
process_group=self.process_group,
sequence_parallel=self.sequence_parallel
)
return reduce_scatter(out, self.process_group)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group)

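The conceptual reason both modes compute the same activations: a reduce_scatter followed by the next layer's all_gather produces the same sum as a single all_reduce, only materialized per rank at a different point (and with less activation memory held per rank in between). A single-process sketch with the collectives emulated by `torch.chunk`/`torch.cat` and hypothetical sizes:

```python
import torch

# Per-rank partial outputs of a RowParallelLinear / ParallelFusedDenseGeluDense
# (hypothetical sizes).
world_size, tokens, dim = 2, 8, 4
partials = [torch.randn(tokens, dim) for _ in range(world_size)]

# sequence_parallel=False: all_reduce, every rank holds the full sum.
all_reduced = sum(partials)

# sequence_parallel=True: reduce_scatter leaves each rank with one token shard
# of the sum; the next ColumnParallelLinear all-gathers the shards back before
# its matmul.
shards = list(torch.chunk(sum(partials), world_size, dim=0))
regathered = torch.cat(shards, dim=0)

# Both paths feed the next layer the same activations; only where the full
# tensor is materialized per rank differs.
assert torch.allclose(regathered, all_reduced)
```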
View File

@ -14,7 +14,7 @@ if "reduce_scatter_tensor" not in dir(torch.distributed):
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
# Raw operation, oes does support autograd, but does support async
# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
@ -24,7 +24,7 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
return output, handle
# Raw operation, oes does support autograd, but does support async
# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0
@ -36,6 +36,13 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bo
return output, handle
# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
input_ = input_.contiguous()
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
return input_, handle
class AllGatherFunc(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""
@ -74,12 +81,30 @@ class ReduceScatterFunc(torch.autograd.Function):
reduce_scatter = ReduceScatterFunc.apply
def sync_sequence_parallel_params(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _sequence_parallel=True in the same order,
class AllReduceFunc(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""
@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = all_reduce_raw(input_, process_group)
return output
@staticmethod
def backward(ctx, grad_output: Tensor):
return grad_output, None
# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _shared_params=True in the same order,
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
params_seqparallel = {name: p for name, p in model.named_parameters()
if getattr(p, '_sequence_parallel', False)}
for _, p in sorted(params_seqparallel.items()):
pamams_shared = {name: p for name, p in model.named_parameters()
if getattr(p, '_shared_params', False)}
for _, p in sorted(pamams_shared.items()):
with torch.no_grad():
# Broadcast needs src to be global rank, not group rank
torch.distributed.broadcast(
@ -94,8 +119,9 @@ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: Proc
params_seqparallel = {name: p for name, p in model.named_parameters()
if getattr(p, '_sequence_parallel', False)}
grads = [p.grad for _, p in sorted(params_seqparallel.items())]
with torch.no_grad():
coalesced = torch._utils._flatten_dense_tensors(grads)
torch.distributed.all_reduce(coalesced, group=process_group)
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
if grads:
with torch.no_grad():
coalesced = torch._utils._flatten_dense_tensors(grads)
torch.distributed.all_reduce(coalesced, group=process_group)
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

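The grad coalescing at the end of `allreduce_sequence_parallel_grad` (and the new `if grads:` guard, needed because with `sequence_parallel=False` no parameter carries the `_sequence_parallel` tag) uses the private flatten/unflatten helpers the diff already calls. A single-process sketch, with the all_reduce replaced by an in-place multiply standing in for a 2-rank sum:

```python
import torch

# Gradients of the tagged (sequence-parallel) parameters on one rank
# (hypothetical shapes; the real code pulls them from model.named_parameters()).
grads = [torch.ones(4), torch.full((2, 3), 2.0)]

# Coalesce into one flat buffer so a single all_reduce covers all of them.
coalesced = torch._utils._flatten_dense_tensors(grads)
coalesced *= 2  # stand-in for torch.distributed.all_reduce summing over 2 ranks

# Unflatten and copy the reduced values back into the original grad tensors.
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)

assert torch.equal(grads[0], torch.full((4,), 2.0))
assert torch.equal(grads[1], torch.full((2, 3), 4.0))
```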
View File

@ -23,10 +23,12 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('has_pos_emb', [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024])
def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
head_dim = 64
assert dim % head_dim == 0
num_heads = dim // head_dim
@ -59,7 +61,8 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True,
rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
pad_vocab_size_multiple=8 * world_size)
pad_vocab_size_multiple=8 * world_size,
sequence_parallel=sequence_parallel)
model_pt = GPTLMHeadModel(config, device=device)
def init_layer_norm(module):
@ -75,16 +78,15 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
torch.distributed.all_gather_into_tensor(
sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
)
sequence_parallel_nparams = sum(p.numel() for p in model.parameters()
if getattr(p, '_sequence_parallel', False))
sequence_parallel_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
shared_nparams = sum(p.numel() for p in model.parameters()
if getattr(p, '_shared_params', False))
shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
torch.distributed.all_gather_into_tensor(
sequence_parallel_nparams_all, torch.tensor([sequence_parallel_nparams], device=device),
group=process_group
shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
)
assert torch.all(sequence_parallel_nparams_all == sequence_parallel_nparams)
assert total_nparams == ((sharded_nparams_all - sequence_parallel_nparams_all).sum().item()
+ sequence_parallel_nparams)
assert torch.all(shared_nparams_all == shared_nparams)
assert total_nparams == ((sharded_nparams_all - shared_nparams_all).sum().item()
+ shared_nparams)
# vocab_size has been rounded up here
partition_vocab_size = config.vocab_size // world_size
@ -96,6 +98,8 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
with torch.autocast(device_type='cuda', dtype=dtype):
out = model(input_ids[:, :-1]).logits
if not sequence_parallel:
out = rearrange(out, 'b s d -> (b s) d')
out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, 'b s d -> (b s) d')
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(

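The parameter-count assertion in this test relies on the shared (norm) parameters being replicated on every rank while everything else is sharded, so the global count is the per-rank counts minus the replicas, plus the replicas once. A tiny worked example with made-up numbers:

```python
# Hypothetical counts for a 2-way tensor-parallel model: 10 replicated (shared)
# norm parameters on every rank and 100 sharded parameters split 50/50.
shared_nparams = 10
sharded_nparams_all = [50 + shared_nparams, 50 + shared_nparams]  # per-rank totals
total_nparams = 100 + shared_nparams                              # unsharded reference model

# Same identity as the assert in the test: subtract the replicas from every
# rank's count, then add them back exactly once.
assert total_nparams == sum(n - shared_nparams for n in sharded_nparams_all) + shared_nparams
```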
View File

@ -23,11 +23,13 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('dim', [1024])
def test_block_parallel(dim, world_size, dtype):
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
head_dim = 64
assert dim % head_dim == 0
num_heads = dim // head_dim
@ -41,7 +43,7 @@ def test_block_parallel(dim, world_size, dtype):
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 8
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype,
@ -51,8 +53,12 @@ def test_block_parallel(dim, world_size, dtype):
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_()
if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_()
else:
x = x_pt.detach().clone().requires_grad_()
residual = residual_pt.detach().clone().requires_grad_()
mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True, device=device, dtype=dtype)
@ -69,12 +75,12 @@ def test_block_parallel(dim, world_size, dtype):
mixer_cls = partial(ParallelMHA, num_heads=num_heads,
process_group=parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
device=device, dtype=dtype)
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
mlp_cls = partial(ParallelFusedDenseGeluDense, hidden_features=4 * dim,
process_group=parallel_state.get_tensor_model_parallel_group(),
device=device, dtype=dtype)
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True,
sequence_parallel=True)
sequence_parallel=sequence_parallel, mark_shared_params=True)
partition_dim = dim // world_size
partition_hidden_dim = 4 * dim // world_size
@ -115,25 +121,34 @@ def test_block_parallel(dim, world_size, dtype):
out_pt, out_residual_pt = [rearrange(x, 'b s d -> (b s) d') for x in [out_pt, out_residual_pt]]
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
)
assert torch.allclose(
out_residual, out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
out_residual,
out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_residual_pt,
rtol=rtol, atol=atol
)
out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
(out_pt + 2 * out_residual_pt).backward(g)
(out + 2 * out_residual).backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
rtol=rtol, atol=atol
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol / 100 # magnitude of x.grad is quite small
)
assert torch.allclose(
residual.grad, residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
residual.grad,
residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else residual_pt.grad,
rtol=rtol, atol=atol
)
# The error for d_weight and d_bias is quite a bit higher

View File

@ -19,10 +19,12 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('has_pos_emb', [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024])
def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
vocab_size = 50264
seqlen = 2048
assert vocab_size % world_size == 0
@ -46,7 +48,7 @@ def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
device=device, dtype=dtype)
model = ParallelGPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
parallel_state.get_tensor_model_parallel_group(),
device=device, dtype=dtype)
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
partition_vocab_size = vocab_size // world_size
partition_dim = dim // world_size
with torch.no_grad():
@ -62,13 +64,16 @@ def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
out_pt = rearrange(model_pt(input_ids), 'b s d -> (b s) d')
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
)
g = torch.randn_like(out_pt)
out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
parallel_state.destroy_model_parallel()
assert torch.allclose(

View File

@ -21,11 +21,13 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('head_dim', [64, 128])
# @pytest.mark.parametrize('head_dim', [64])
@pytest.mark.parametrize('embed_dim', [1024, 4096])
# @pytest.mark.parametrize('embed_dim', [1024])
def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
assert embed_dim % head_dim == 0
num_heads = embed_dim // head_dim
assert num_heads % world_size == 0
@ -38,7 +40,7 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 8
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=dtype,
@ -47,14 +49,17 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
else:
x = x_pt.detach().clone().requires_grad_()
model_pt = MHA(embed_dim, num_heads, rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True, device=device, dtype=dtype)
partition_dim = embed_dim // world_size
model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
device=device, dtype=dtype)
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
with torch.no_grad():
model.Wqkv.weight.copy_(
@ -75,17 +80,22 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
out_pt = rearrange(model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen)), 'b s d -> (b s) d')
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
)
out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
rtol=rtol, atol=atol
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol / 100 # magnitude of x.grad is quite small
)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(

View File

@ -19,14 +19,15 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('has_bias', [True, False])
# @pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize('out_features', [1024, 4096])
# @pytest.mark.parametrize('out_features', [1024])
@pytest.mark.parametrize('in_features', [1024, 4096])
# @pytest.mark.parametrize('in_features', [4096])
def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtype):
# @pytest.mark.parametrize('has_bias', [False])
@pytest.mark.parametrize('out_features', [1024])
@pytest.mark.parametrize('in_features', [4096])
def test_fused_linear_bias(in_features, out_features, has_bias, sequence_parallel,
world_size, dtype):
assert out_features % world_size == 0
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
@ -37,18 +38,21 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 8
batch_size = 2
seqlen = 512
assert batch_size * seqlen % world_size == 0
x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
requires_grad=True)
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
else:
x = x_pt.detach().clone().requires_grad_()
model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
partition_out_features = out_features // world_size
model = ColumnParallelLinear(in_features, out_features,
parallel_state.get_tensor_model_parallel_group(), bias=has_bias,
device=device, dtype=dtype)
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
with torch.no_grad():
model.weight.copy_(
model_pt.weight[rank * partition_out_features:(rank + 1) * partition_out_features]
@ -73,7 +77,9 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol
)
# The error for d_weight and d_bias is quite a bit higher
@ -94,13 +100,14 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('has_bias2', [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize('out_features', [1024, 4096])
# @pytest.mark.parametrize('out_features', [1024])
@pytest.mark.parametrize('in_features', [1024, 4096])
# @pytest.mark.parametrize('in_features', [1024])
def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size, dtype):
@pytest.mark.parametrize('out_features', [4096])
@pytest.mark.parametrize('in_features', [1024])
def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_parallel,
world_size, dtype):
assert out_features % world_size == 0
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
@ -111,7 +118,7 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 8
batch_size = 2
seqlen = 512
assert batch_size * seqlen % world_size == 0
x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
@ -120,7 +127,10 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
else:
x = x_pt.detach().clone().requires_grad_()
model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
@ -129,7 +139,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
partition_in_features = in_features // world_size
model = ParallelFusedDenseGeluDense(in_features, out_features, in_features,
process_group=parallel_state.get_tensor_model_parallel_group(),
bias2=has_bias2 and rank == 0, device=device, dtype=dtype)
bias2=has_bias2 and rank == 0,
sequence_parallel=sequence_parallel,
device=device, dtype=dtype)
with torch.no_grad():
model.fc1.weight.copy_(
@ -148,16 +160,21 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
)
out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol
)
# The error for d_weight and d_bias is quite a bit higher