From 93383bd55bfffb0fa2c4584c4849971152397035 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 7 Jan 2023 13:45:22 -0800 Subject: [PATCH] [TP] Implement TensorParallel without sequence parallel --- flash_attn/models/gpt.py | 40 ++++++---- flash_attn/modules/block.py | 16 +++- flash_attn/modules/embedding.py | 8 +- flash_attn/modules/mha.py | 10 +-- flash_attn/ops/fused_dense.py | 97 ++++++++++++++---------- flash_attn/utils/distributed.py | 50 +++++++++--- tests/models/test_gpt_parallel.py | 24 +++--- tests/modules/test_block_parallel.py | 45 +++++++---- tests/modules/test_embedding_parallel.py | 13 +++- tests/modules/test_mha_parallel.py | 26 +++++-- tests/ops/test_fused_dense_parallel.py | 61 +++++++++------ 11 files changed, 257 insertions(+), 133 deletions(-) diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index 417fac9..656bf9e 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -20,7 +20,7 @@ from flash_attn.modules.mha import MHA, ParallelMHA from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense from flash_attn.modules.block import Block from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings -from flash_attn.utils.distributed import sync_sequence_parallel_params +from flash_attn.utils.distributed import sync_shared_params from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.generation import GenerationMixin @@ -62,7 +62,9 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt mha_cls = MHA if process_group is None else ParallelMHA serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv} if process_group is None else {}) - parallel_kwargs = {'process_group': process_group} if process_group is not None else {} + parallel_kwargs = ({'process_group': process_group, + 'sequence_parallel': getattr(config, 'sequence_parallel', True)} + if process_group is not None else {}) mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop, softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx, rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base, @@ -99,7 +101,9 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp if FusedDenseGeluDense is None: raise ImportError('fused_dense is not installed') mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense - parallel_kwargs = {'process_group': process_group} if process_group is not None else {} + parallel_kwargs = ({'process_group': process_group, + 'sequence_parallel': getattr(config, 'sequence_parallel', True)} + if process_group is not None else {}) mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl, **parallel_kwargs, **factory_kwargs) elif fused_dense_sqrelu_dense: @@ -113,13 +117,15 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None): factory_kwargs = {'device': device, 'dtype': dtype} + sequence_parallel = getattr(config, 'sequence_parallel', True) mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs) mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs) norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs) block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, prenorm=True, 
resid_dropout=config.resid_pdrop, fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False), - sequence_parallel=process_group is not None) + sequence_parallel=sequence_parallel and process_group is not None, + mark_shared_params=process_group is not None) block.layer_idx = layer_idx return block @@ -180,6 +186,7 @@ class GPTModel(GPTPreTrainedModel): super().__init__(config) factory_kwargs = {'device': device, 'dtype': dtype} self.process_group = process_group + self.sequence_parallel = getattr(config, 'sequence_parallel', True) assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu'] self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) if config.vocab_size % self.pad_vocab_size_multiple != 0: @@ -192,7 +199,8 @@ class GPTModel(GPTPreTrainedModel): else: self.embeddings = ParallelGPT2Embeddings( config.hidden_size, config.vocab_size, config.max_position_embeddings, - process_group=process_group, **factory_kwargs + process_group=process_group, sequence_parallel=self.sequence_parallel, + **factory_kwargs ) self.emb_drop = nn.Dropout(config.embd_pdrop) @@ -209,10 +217,13 @@ class GPTModel(GPTPreTrainedModel): # is the final layer norm. self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs) - # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. if process_group is not None: for p in self.ln_0.parameters(): - p._sequence_parallel = True + # Mark the norm parameters as "shared_params" so that we sync their values at init. + p._shared_params = True + # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads. + if self.sequence_parallel: + p._sequence_parallel = True self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs) @@ -224,14 +235,14 @@ class GPTModel(GPTPreTrainedModel): def tie_weights(self): if self.process_group is not None: - sync_sequence_parallel_params(self, self.process_group) + sync_shared_params(self, self.process_group) def forward(self, input_ids, position_ids=None, inference_params=None): # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen # dimensions so that we can split on it easily, in case of small batch size. # Only the attention layers need to know the seqlen. 
embedding_kwargs = ({'combine_batch_seqlen_dim': True} - if self.process_group is not None else {}) + if self.process_group is not None and self.sequence_parallel else {}) hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs) # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable if not self.fused_dropout_add_ln: @@ -243,7 +254,8 @@ class GPTModel(GPTPreTrainedModel): self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True, residual_in_fp32=True ) - mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {}) + mixer_kwargs = ({'seqlen': input_ids.shape[1]} + if self.process_group is not None and self.sequence_parallel else {}) if inference_params is not None: mixer_kwargs['inference_params'] = inference_params for layer in self.layers: @@ -263,8 +275,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): else: if ColumnParallelLinear is None: raise ImportError('fused_dense_lib is not installed') - self.lm_head = ColumnParallelLinear(config.n_embd, config.vocab_size, process_group, - bias=False, **factory_kwargs) + self.lm_head = ColumnParallelLinear( + config.n_embd, config.vocab_size, process_group, bias=False, + sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs + ) # Initialize weights and apply final processing self.apply(partial(_init_weights, n_layer=config.num_hidden_layers, initializer_range=config.initializer_range)) @@ -273,7 +287,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): def tie_weights(self): self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight if self.process_group is not None: - sync_sequence_parallel_params(self, self.process_group) + sync_shared_params(self, self.process_group) def forward(self, input_ids, position_ids=None, inference_params=None): """ diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py index cdfb61b..5043733 100644 --- a/flash_attn/modules/block.py +++ b/flash_attn/modules/block.py @@ -23,7 +23,8 @@ class Block(nn.Module): def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0., - fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False): + fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False, + mark_shared_params=False): """ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. This is for performance reason: for post-norm architecture, returning the input allows us @@ -51,6 +52,12 @@ class Block(nn.Module): assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed' assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout) + # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, + # then the input to each worker in the tensor parallel group will be different. + # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. + # For now this is not an issue because we always use sequence_parallel=True during training + # and only use sequence_parallel=False during inference. + # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. 
if sequence_parallel: for p in self.norm1.parameters(): @@ -58,6 +65,13 @@ class Block(nn.Module): if hasattr(self, 'norm2'): for p in self.norm2.parameters(): p._sequence_parallel = True + # Mark the norm parameters as "shared_params" so that we sync their values at init. + if mark_shared_params: + for p in self.norm1.parameters(): + p._shared_params = True + if hasattr(self, 'norm2'): + for p in self.norm2.parameters(): + p._shared_params = True def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_kwargs=None): diff --git a/flash_attn/modules/embedding.py b/flash_attn/modules/embedding.py index 0db86c0..eee184f 100644 --- a/flash_attn/modules/embedding.py +++ b/flash_attn/modules/embedding.py @@ -6,7 +6,7 @@ from torch import Tensor from einops import rearrange -from flash_attn.utils.distributed import reduce_scatter +from flash_attn.utils.distributed import reduce_scatter, all_reduce class GPT2Embeddings(nn.Module): @@ -130,13 +130,14 @@ class ColumnParallelEmbedding(nn.Embedding): class ParallelGPT2Embeddings(nn.Module): def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group, - padding_idx=None, device=None, dtype=None): + padding_idx=None, sequence_parallel=True, device=None, dtype=None): """ If max_position_embeddings <= 0, there's no position embeddings """ factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.process_group = process_group + self.sequence_parallel = sequence_parallel self.word_embeddings = VocabParallelEmbedding( vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group, **factory_kwargs @@ -167,4 +168,5 @@ class ParallelGPT2Embeddings(nn.Module): embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings if combine_batch_seqlen_dim: embeddings = rearrange(embeddings, 'b s d -> (b s) d') - return embeddings if world_size <= 1 else reduce_scatter(embeddings, self.process_group) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group) diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 129d233..3790f08 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -497,11 +497,10 @@ class ParallelMHA(nn.Module): def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0, - rotary_emb_scale_base=0, - use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None: + rotary_emb_scale_base=0, use_flash_attn=False, checkpointing=False, + sequence_parallel=True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() - self.process_group = process_group self.embed_dim = embed_dim self.causal = causal self.layer_idx = layer_idx @@ -521,12 +520,13 @@ class ParallelMHA(nn.Module): if ColumnParallelLinear is None or RowParallelLinear is None: raise ImportError('fused_dense is not installed') self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias, - **factory_kwargs) + sequence_parallel=sequence_parallel, **factory_kwargs) inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) # output projection always have the bias (for now) - self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, **factory_kwargs) + self.out_proj 
= RowParallelLinear(embed_dim, embed_dim, process_group, + sequence_parallel=sequence_parallel, **factory_kwargs) def forward(self, x, seqlen=None, **kwargs): """ diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py index c8d6e0f..b4e3e28 100644 --- a/flash_attn/ops/fused_dense.py +++ b/flash_attn/ops/fused_dense.py @@ -15,26 +15,29 @@ from torch.cuda.amp import custom_bwd, custom_fwd import fused_dense_lib as fused_dense_cuda from flash_attn.ops.gelu_activation import gelu_bwd -from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, reduce_scatter +from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw +from flash_attn.utils.distributed import reduce_scatter, all_reduce class FusedDenseFunc(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, weight, bias, return_residual=False, process_group=None): + def forward(ctx, x, weight, bias, return_residual=False, process_group=None, + sequence_parallel=True): """ - If process_group is not None, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather_raw of x before doing the matmul. + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather_raw of x before doing the matmul. """ ctx.compute_weight_gradient = weight.requires_grad ctx.return_residual = return_residual ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel if torch.is_autocast_enabled(): x = x.to(dtype=torch.get_autocast_gpu_dtype()) x = x.contiguous() - if process_group is not None: + if process_group is not None and sequence_parallel: # We want to kick off the all_gather early, before weight dtype conversion total_x, handle_x = all_gather_raw(x, process_group, async_op=True) else: @@ -44,7 +47,7 @@ class FusedDenseFunc(torch.autograd.Function): weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None weight = weight.contiguous() - if process_group is not None: + if process_group is not None and sequence_parallel: handle_x.wait() batch_shape, n = total_x.shape[:-1], total_x.shape[-1] batch_dim = batch_shape.numel() @@ -66,9 +69,10 @@ class FusedDenseFunc(torch.autograd.Function): grad_input, = args grad_input = grad_input.contiguous() process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel if ctx.compute_weight_gradient: x, weight = ctx.saved_tensors - if process_group is not None: + if process_group is not None and sequence_parallel: total_x, handle_x = all_gather_raw(x, process_group, async_op=True) else: total_x = x @@ -86,13 +90,13 @@ class FusedDenseFunc(torch.autograd.Function): grad_output, weight) grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) if process_group is not None: - grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group, - async_op=True) + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) else: grad_input = None if ctx.needs_input_grad[1]: assert ctx.compute_weight_gradient - if process_group is not None: + if process_group is not None and sequence_parallel: handle_x.wait() grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] @@ -102,15 +106,17 @@ class FusedDenseFunc(torch.autograd.Function): grad_bias = 
grad_output if ctx.needs_input_grad[2] else None if process_group is not None and ctx.needs_input_grad[0]: handle_grad_input.wait() - return grad_input, grad_weight, grad_bias, None, None + return grad_input, grad_weight, grad_bias, None, None, None def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, - return_residual: bool = False, process_group: Optional[ProcessGroup] = None): + return_residual: bool = False, process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True): dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] or (x.dtype == torch.float32 and torch.is_autocast_enabled())) if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group) + return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, + sequence_parallel) else: assert process_group is None out = F.linear(x, weight, bias) @@ -136,7 +142,7 @@ class FusedDense(nn.Linear): class ColumnParallelLinear(nn.Linear): def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup, - bias: bool = True, device=None, dtype=None) -> None: + bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None: world_size = torch.distributed.get_world_size(process_group) if out_features % world_size != 0: raise ValueError(f'out_features ({out_features}) must be divisible by ' @@ -144,19 +150,20 @@ class ColumnParallelLinear(nn.Linear): super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype) self.process_group = process_group + self.sequence_parallel = sequence_parallel def forward(self, x): - """ - We're doing Tensor Parallel with sequence parallelism: we do an all_gather of - x before doing the matmul. - """ - return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group) + # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + # we do an all_gather of x before doing the matmul. + # If not, then the input is already gathered. + return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, + sequence_parallel=self.sequence_parallel) class RowParallelLinear(nn.Linear): def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup, - bias: bool = True, device=None, dtype=None) -> None: + bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None: world_size = torch.distributed.get_world_size(process_group) rank = torch.distributed.get_rank(process_group) if in_features % world_size != 0: @@ -166,6 +173,7 @@ class RowParallelLinear(nn.Linear): super().__init__(in_features // world_size, out_features, bias=bias and rank == 0, device=device, dtype=dtype) self.process_group = process_group + self.sequence_parallel = sequence_parallel def forward(self, x): """ @@ -173,7 +181,8 @@ class RowParallelLinear(nn.Linear): a reduce_scatter of the result. 
""" out = fused_dense_func(x, self.weight, self.bias) - return reduce_scatter(out, self.process_group) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return reduce_fn(out, self.process_group) class FusedDenseGeluDenseFunc(torch.autograd.Function): @@ -181,10 +190,11 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): @staticmethod @custom_fwd def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False, - checkpoint_lvl=0, heuristic=0, process_group=None): + checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True): """ - If process_group is not None, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather of x before doing the matmul. + If sequence_parallel=False, then the input is already gathered. checkpoint_lvl: 0: no recomputation in the bwd @@ -197,13 +207,14 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): assert checkpoint_lvl in [0, 1, 2] ctx.return_residual = return_residual ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel ctx.checkpoint_lvl = checkpoint_lvl ctx.heuristic = heuristic if torch.is_autocast_enabled(): x = x.to(dtype=torch.get_autocast_gpu_dtype()) x = x.contiguous() - if process_group is not None: + if process_group is not None and sequence_parallel: # We want to kick off the all_gather early, before weight dtype conversion total_x, handle_x = all_gather_raw(x, process_group, async_op=True) else: @@ -218,7 +229,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): bias1 = bias1.contiguous() if bias1 is not None else None weight2 = weight2.contiguous() bias2 = bias2.contiguous() if bias2 is not None else None - if process_group is not None: + if process_group is not None and sequence_parallel: handle_x.wait() batch_shape, n = total_x.shape[:-1], total_x.shape[-1] batch_dim = batch_shape.numel() @@ -257,13 +268,14 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): grad_input, = args grad_input = grad_input.contiguous() process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel x, weight1, weight2, *rest = ctx.saved_tensors - if process_group is None: + if process_group is None or not sequence_parallel: total_x = x batch_shape = grad_output.shape[:-1] batch_dim = batch_shape.numel() if checkpoint_lvl in [0, 1]: - if process_group is not None: + if process_group is not None and sequence_parallel: total_x, handle_x = all_gather_raw(x, process_group, async_op=True) if checkpoint_lvl == 0: gelu_in, output1 = rest @@ -272,7 +284,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): output1 = F.gelu(gelu_in, approximate='tanh') elif checkpoint_lvl == 2: bias1, = rest - if process_group is not None: + if process_group is not None and sequence_parallel: total_x, _ = all_gather_raw(x, process_group) if ctx.heuristic == -1: gelu_in = F.linear(total_x, weight1, bias1) @@ -314,13 +326,13 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): grad_gelu, weight1) grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) if process_group is not None: - grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group, - async_op=True) + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) else: grad_input = None 
if ctx.heuristic == -1: if ctx.needs_input_grad[1]: - if process_group is not None: + if process_group is not None and sequence_parallel: handle_x.wait() grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu, @@ -331,7 +343,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None else: if ctx.needs_input_grad[1]: - if process_group is not None: + if process_group is not None and sequence_parallel: handle_x.wait() grad_weight1 = F.linear(grad_gelu.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()) @@ -340,7 +352,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): if process_group is not None and ctx.needs_input_grad[0]: handle_grad_input.wait() return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, - None, None, None, None, None) + None, None, None, None, None, None) def fused_dense_gelu_dense_func( @@ -348,15 +360,16 @@ def fused_dense_gelu_dense_func( bias2: Optional[Tensor] = None, save_pre_act: bool = True, return_residual: bool = False, checkpoint_lvl: int = 0, heuristic: int = 0, - process_group: Optional[ProcessGroup] = None + process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True ): dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] or (x.dtype == torch.float32 and torch.is_autocast_enabled())) if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda) and (bias2 is None or bias2.is_cuda) and dtype_eligible): return FusedDenseGeluDenseFunc.apply( - x, weight1, bias1, weight2, bias2, - save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group + x, weight1, bias1, weight2, bias2, save_pre_act, return_residual, + checkpoint_lvl, heuristic, process_group, sequence_parallel ) else: assert process_group is None @@ -418,7 +431,7 @@ class ParallelFusedDenseGeluDense(nn.Module): def __init__(self, in_features, hidden_features, out_features=None, process_group: ProcessGroup = None, bias1=True, bias2=True, - checkpoint_lvl=0, heuristic=0, device=None, dtype=None): + sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None): """ process_group is required. We're doing Tensor Parallel with sequence parallelism: we do an all_gather of x before doing the matmul, gelu, then matmul. 
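For reference, the communication pattern that the new `sequence_parallel` flag toggles in the layers above can be sketched with plain `torch.distributed` collectives: with `sequence_parallel=True` the input is sharded over the token dimension, so the column-parallel matmul all-gathers its input and the row-parallel matmul reduce-scatters its output; with `sequence_parallel=False` the input is replicated and only an all-reduce of the output is needed. This is an illustrative sketch under those assumptions, not part of the patch, and the helper names are hypothetical:

```python
import torch
import torch.distributed as dist

def column_parallel_forward(x, weight_shard, process_group, sequence_parallel=True):
    # Column parallel: each rank holds a slice of the output features.
    if sequence_parallel:
        # Input is sharded along the (batch * seqlen) dimension: all-gather it first.
        world_size = dist.get_world_size(process_group)
        total_x = torch.empty(world_size * x.shape[0], *x.shape[1:],
                              dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(total_x, x.contiguous(), group=process_group)
    else:
        # Without sequence parallelism the input is already replicated on every rank.
        total_x = x
    return total_x @ weight_shard.t()

def row_parallel_forward(x_shard, weight_shard, process_group, sequence_parallel=True):
    # Row parallel: each rank holds a slice of the input features, so the partial
    # outputs must be summed across the tensor-parallel group.
    partial = x_shard @ weight_shard.t()
    if sequence_parallel:
        # Sum and re-shard over the token dimension in one collective.
        world_size = dist.get_world_size(process_group)
        out = torch.empty(partial.shape[0] // world_size, *partial.shape[1:],
                          dtype=partial.dtype, device=partial.device)
        dist.reduce_scatter_tensor(out, partial.contiguous(), group=process_group)
        return out
    # sequence_parallel=False: keep the full (replicated) output, just all-reduce it.
    dist.all_reduce(partial, group=process_group)
    return partial
```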
@@ -441,6 +454,7 @@ class ParallelFusedDenseGeluDense(nn.Module): if out_features is None: out_features = in_features self.process_group = process_group + self.sequence_parallel = sequence_parallel self.checkpoint_lvl = checkpoint_lvl self.heuristic = heuristic self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, @@ -452,6 +466,9 @@ class ParallelFusedDenseGeluDense(nn.Module): out = fused_dense_gelu_dense_func( x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl, - heuristic=self.heuristic, process_group=self.process_group + heuristic=self.heuristic, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel ) - return reduce_scatter(out, self.process_group) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return reduce_fn(out, self.process_group) diff --git a/flash_attn/utils/distributed.py b/flash_attn/utils/distributed.py index 16c8a28..09578e3 100644 --- a/flash_attn/utils/distributed.py +++ b/flash_attn/utils/distributed.py @@ -14,7 +14,7 @@ if "reduce_scatter_tensor" not in dir(torch.distributed): torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base -# Raw operation, oes does support autograd, but does support async +# Raw operation, does not support autograd, but does support async def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], @@ -24,7 +24,7 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = return output, handle -# Raw operation, oes does support autograd, but does support async +# Raw operation, does not support autograd, but does support async def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 @@ -36,6 +36,13 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bo return output, handle +# Raw operation, does not support autograd, but does support async +def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + input_ = input_.contiguous() + handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) + return input_, handle + + class AllGatherFunc(torch.autograd.Function): """Gather the input from sequence parallel region and concatenate.""" @@ -74,12 +81,30 @@ class ReduceScatterFunc(torch.autograd.Function): reduce_scatter = ReduceScatterFunc.apply -def sync_sequence_parallel_params(model: torch.nn.Module, process_group: ProcessGroup): - # We want to iterate over parameters with _sequence_parallel=True in the same order, +class AllReduceFunc(torch.autograd.Function): + """Gather the input from sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = all_reduce_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + return grad_output, None + + +# Supports autograd, but does not support async +all_reduce = AllReduceFunc.apply + + +def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup): + # We want to iterate over parameters with _shared_params=True in the same order, # as different 
ranks might have different number of parameters (e.g., only rank 0 has bias). - params_seqparallel = {name: p for name, p in model.named_parameters() - if getattr(p, '_sequence_parallel', False)} - for _, p in sorted(params_seqparallel.items()): + pamams_shared = {name: p for name, p in model.named_parameters() + if getattr(p, '_shared_params', False)} + for _, p in sorted(pamams_shared.items()): with torch.no_grad(): # Broadcast needs src to be global rank, not group rank torch.distributed.broadcast( @@ -94,8 +119,9 @@ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: Proc params_seqparallel = {name: p for name, p in model.named_parameters() if getattr(p, '_sequence_parallel', False)} grads = [p.grad for _, p in sorted(params_seqparallel.items())] - with torch.no_grad(): - coalesced = torch._utils._flatten_dense_tensors(grads) - torch.distributed.all_reduce(coalesced, group=process_group) - for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): - buf.copy_(synced) + if grads: + with torch.no_grad(): + coalesced = torch._utils._flatten_dense_tensors(grads) + torch.distributed.all_reduce(coalesced, group=process_group) + for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) diff --git a/tests/models/test_gpt_parallel.py b/tests/models/test_gpt_parallel.py index bd91a1f..dd4451b 100644 --- a/tests/models/test_gpt_parallel.py +++ b/tests/models/test_gpt_parallel.py @@ -23,10 +23,12 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 # @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) # @pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('sequence_parallel', [True, False]) +# @pytest.mark.parametrize('sequence_parallel', [False]) @pytest.mark.parametrize('has_pos_emb', [True, False]) # @pytest.mark.parametrize('has_pos_emb', [True]) @pytest.mark.parametrize('dim', [1024]) -def test_gpt_parallel(dim, has_pos_emb, world_size, dtype): +def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): head_dim = 64 assert dim % head_dim == 0 num_heads = dim // head_dim @@ -59,7 +61,8 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype): scale_attn_by_inverse_layer_idx=True, use_flash_attn=True, fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True, rotary_emb_fraction=0.0 if has_pos_emb else 0.5, - pad_vocab_size_multiple=8 * world_size) + pad_vocab_size_multiple=8 * world_size, + sequence_parallel=sequence_parallel) model_pt = GPTLMHeadModel(config, device=device) def init_layer_norm(module): @@ -75,16 +78,15 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype): torch.distributed.all_gather_into_tensor( sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group ) - sequence_parallel_nparams = sum(p.numel() for p in model.parameters() - if getattr(p, '_sequence_parallel', False)) - sequence_parallel_nparams_all = torch.empty(world_size, dtype=torch.long, device=device) + shared_nparams = sum(p.numel() for p in model.parameters() + if getattr(p, '_shared_params', False)) + shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device) torch.distributed.all_gather_into_tensor( - sequence_parallel_nparams_all, torch.tensor([sequence_parallel_nparams], device=device), - group=process_group + shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group ) - assert 
torch.all(sequence_parallel_nparams_all == sequence_parallel_nparams) - assert total_nparams == ((sharded_nparams_all - sequence_parallel_nparams_all).sum().item() - + sequence_parallel_nparams) + assert torch.all(shared_nparams_all == shared_nparams) + assert total_nparams == ((sharded_nparams_all - shared_nparams_all).sum().item() + + shared_nparams) # vocab_size has been rounded up here partition_vocab_size = config.vocab_size // world_size @@ -96,6 +98,8 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype): with torch.autocast(device_type='cuda', dtype=dtype): out = model(input_ids[:, :-1]).logits + if not sequence_parallel: + out = rearrange(out, 'b s d -> (b s) d') out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, 'b s d -> (b s) d') partition_batch_dim = batch_size * seqlen // world_size assert torch.allclose( diff --git a/tests/modules/test_block_parallel.py b/tests/modules/test_block_parallel.py index cf2397a..34a9cff 100644 --- a/tests/modules/test_block_parallel.py +++ b/tests/modules/test_block_parallel.py @@ -23,11 +23,13 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 @pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else [])) -# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +# @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) # @pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('sequence_parallel', [True, False]) +# @pytest.mark.parametrize('sequence_parallel', [False]) @pytest.mark.parametrize('dim', [1024]) -def test_block_parallel(dim, world_size, dtype): +def test_block_parallel(dim, sequence_parallel, world_size, dtype): head_dim = 64 assert dim % head_dim == 0 num_heads = dim // head_dim @@ -41,7 +43,7 @@ def test_block_parallel(dim, world_size, dtype): rank = parallel_state.get_tensor_model_parallel_rank() # set seed torch.random.manual_seed(0) - batch_size = 8 + batch_size = 2 seqlen = 1024 assert (batch_size * seqlen) % world_size == 0 x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, @@ -51,8 +53,12 @@ def test_block_parallel(dim, world_size, dtype): # as rank 0 will have an extra bias that changes the RNG. # If we don't divide by batch_size, the gradient gets a bit too large. 
g = torch.randn_like(x_pt) / 32 - x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() - residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_() + if sequence_parallel: + x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() + residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_() + else: + x = x_pt.detach().clone().requires_grad_() + residual = residual_pt.detach().clone().requires_grad_() mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2), use_flash_attn=True, device=device, dtype=dtype) @@ -69,12 +75,12 @@ def test_block_parallel(dim, world_size, dtype): mixer_cls = partial(ParallelMHA, num_heads=num_heads, process_group=parallel_state.get_tensor_model_parallel_group(), rotary_emb_dim=int(head_dim // 2), use_flash_attn=True, - device=device, dtype=dtype) + sequence_parallel=sequence_parallel, device=device, dtype=dtype) mlp_cls = partial(ParallelFusedDenseGeluDense, hidden_features=4 * dim, process_group=parallel_state.get_tensor_model_parallel_group(), - device=device, dtype=dtype) + sequence_parallel=sequence_parallel, device=device, dtype=dtype) model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True, - sequence_parallel=True) + sequence_parallel=sequence_parallel, mark_shared_params=True) partition_dim = dim // world_size partition_hidden_dim = 4 * dim // world_size @@ -115,25 +121,34 @@ def test_block_parallel(dim, world_size, dtype): out_pt, out_residual_pt = [rearrange(x, 'b s d -> (b s) d') for x in [out_pt, out_residual_pt]] partition_batch_dim = batch_size * seqlen // world_size assert torch.allclose( - out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], + out, + out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else out_pt, rtol=rtol, atol=atol ) assert torch.allclose( - out_residual, out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], + out_residual, + out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else out_residual_pt, rtol=rtol, atol=atol ) - out_pt.backward(g) - out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]) + (out_pt + 2 * out_residual_pt).backward(g) + (out + 2 * out_residual).backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else g) allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group()) parallel_state.destroy_model_parallel() assert torch.allclose( - x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], - rtol=rtol, atol=atol + x.grad, + x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else x_pt.grad, + rtol=rtol, atol=atol / 100 # magnitude of x.grad is quite small ) assert torch.allclose( - residual.grad, residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], + residual.grad, + residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else residual_pt.grad, rtol=rtol, atol=atol ) # The error for d_weight and d_bias is quite a bit higher diff --git a/tests/modules/test_embedding_parallel.py b/tests/modules/test_embedding_parallel.py index d2de870..a6be31e 100644 --- a/tests/modules/test_embedding_parallel.py +++ 
b/tests/modules/test_embedding_parallel.py @@ -19,10 +19,12 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 # @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) # @pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('sequence_parallel', [True, False]) +# @pytest.mark.parametrize('sequence_parallel', [False]) @pytest.mark.parametrize('has_pos_emb', [True, False]) # @pytest.mark.parametrize('has_pos_emb', [True]) @pytest.mark.parametrize('dim', [1024]) -def test_embedding_parallel(dim, world_size, has_pos_emb, dtype): +def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): vocab_size = 50264 seqlen = 2048 assert vocab_size % world_size == 0 @@ -46,7 +48,7 @@ def test_embedding_parallel(dim, world_size, has_pos_emb, dtype): device=device, dtype=dtype) model = ParallelGPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0, parallel_state.get_tensor_model_parallel_group(), - device=device, dtype=dtype) + sequence_parallel=sequence_parallel, device=device, dtype=dtype) partition_vocab_size = vocab_size // world_size partition_dim = dim // world_size with torch.no_grad(): @@ -62,13 +64,16 @@ def test_embedding_parallel(dim, world_size, has_pos_emb, dtype): out_pt = rearrange(model_pt(input_ids), 'b s d -> (b s) d') partition_batch_dim = batch_size * seqlen // world_size assert torch.allclose( - out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], + out, + out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else out_pt, rtol=rtol, atol=atol ) g = torch.randn_like(out_pt) out_pt.backward(g) - out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]) + out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else g) parallel_state.destroy_model_parallel() assert torch.allclose( diff --git a/tests/modules/test_mha_parallel.py b/tests/modules/test_mha_parallel.py index 2d3d01f..45afbd9 100644 --- a/tests/modules/test_mha_parallel.py +++ b/tests/modules/test_mha_parallel.py @@ -21,11 +21,13 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) # @pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('sequence_parallel', [True, False]) +# @pytest.mark.parametrize('sequence_parallel', [False]) @pytest.mark.parametrize('head_dim', [64, 128]) # @pytest.mark.parametrize('head_dim', [64]) @pytest.mark.parametrize('embed_dim', [1024, 4096]) # @pytest.mark.parametrize('embed_dim', [1024]) -def test_mha_parallel(embed_dim, head_dim, world_size, dtype): +def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype): assert embed_dim % head_dim == 0 num_heads = embed_dim // head_dim assert num_heads % world_size == 0 @@ -38,7 +40,7 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype): rank = parallel_state.get_tensor_model_parallel_rank() # set seed torch.random.manual_seed(0) - batch_size = 8 + batch_size = 2 seqlen = 1024 assert (batch_size * seqlen) % world_size == 0 x_pt = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=dtype, @@ -47,14 +49,17 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype): # as rank 0 will have an extra bias that changes the RNG. # If we don't divide by batch_size, the gradient gets a bit too large. 
g = torch.randn_like(x_pt) / 32 - x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() + if sequence_parallel: + x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() + else: + x = x_pt.detach().clone().requires_grad_() model_pt = MHA(embed_dim, num_heads, rotary_emb_dim=int(head_dim // 2), use_flash_attn=True, device=device, dtype=dtype) partition_dim = embed_dim // world_size model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(), rotary_emb_dim=int(head_dim // 2), use_flash_attn=True, - device=device, dtype=dtype) + sequence_parallel=sequence_parallel, device=device, dtype=dtype) with torch.no_grad(): model.Wqkv.weight.copy_( @@ -75,17 +80,22 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype): out_pt = rearrange(model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen)), 'b s d -> (b s) d') partition_batch_dim = batch_size * seqlen // world_size assert torch.allclose( - out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], + out, + out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else out_pt, rtol=rtol, atol=atol ) out_pt.backward(g) - out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]) + out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else g) parallel_state.destroy_model_parallel() assert torch.allclose( - x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], - rtol=rtol, atol=atol + x.grad, + x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else x_pt.grad, + rtol=rtol, atol=atol / 100 # magnitude of x.grad is quite small ) # The error for d_weight and d_bias is quite a bit higher assert torch.allclose( diff --git a/tests/ops/test_fused_dense_parallel.py b/tests/ops/test_fused_dense_parallel.py index 3b19a71..9feff05 100644 --- a/tests/ops/test_fused_dense_parallel.py +++ b/tests/ops/test_fused_dense_parallel.py @@ -19,14 +19,15 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 @pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else [])) # @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) -# @pytest.mark.parametrize('world_size', [8]) +# @pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('sequence_parallel', [True, False]) +# @pytest.mark.parametrize('sequence_parallel', [False]) @pytest.mark.parametrize('has_bias', [True, False]) -# @pytest.mark.parametrize('has_bias', [True]) -@pytest.mark.parametrize('out_features', [1024, 4096]) -# @pytest.mark.parametrize('out_features', [1024]) -@pytest.mark.parametrize('in_features', [1024, 4096]) -# @pytest.mark.parametrize('in_features', [4096]) -def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtype): +# @pytest.mark.parametrize('has_bias', [False]) +@pytest.mark.parametrize('out_features', [1024]) +@pytest.mark.parametrize('in_features', [4096]) +def test_fused_linear_bias(in_features, out_features, has_bias, sequence_parallel, + world_size, dtype): assert out_features % world_size == 0 rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) if not torch.distributed.is_initialized(): @@ -37,18 +38,21 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp rank = parallel_state.get_tensor_model_parallel_rank() # set seed 
torch.random.manual_seed(0) - batch_size = 8 + batch_size = 2 seqlen = 512 assert batch_size * seqlen % world_size == 0 x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True) - x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() + if sequence_parallel: + x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() + else: + x = x_pt.detach().clone().requires_grad_() model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype) partition_out_features = out_features // world_size model = ColumnParallelLinear(in_features, out_features, parallel_state.get_tensor_model_parallel_group(), bias=has_bias, - device=device, dtype=dtype) + sequence_parallel=sequence_parallel, device=device, dtype=dtype) with torch.no_grad(): model.weight.copy_( model_pt.weight[rank * partition_out_features:(rank + 1) * partition_out_features] @@ -73,7 +77,9 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp partition_batch_dim = batch_size * seqlen // world_size assert torch.allclose( - x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], + x.grad, + x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else x_pt.grad, rtol=rtol, atol=atol ) # The error for d_weight and d_bias is quite a bit higher @@ -94,13 +100,14 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp # @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) # @pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('sequence_parallel', [True, False]) +# @pytest.mark.parametrize('sequence_parallel', [False]) @pytest.mark.parametrize('has_bias2', [True, False]) # @pytest.mark.parametrize('has_bias2', [True]) -@pytest.mark.parametrize('out_features', [1024, 4096]) -# @pytest.mark.parametrize('out_features', [1024]) -@pytest.mark.parametrize('in_features', [1024, 4096]) -# @pytest.mark.parametrize('in_features', [1024]) -def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size, dtype): +@pytest.mark.parametrize('out_features', [4096]) +@pytest.mark.parametrize('in_features', [1024]) +def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_parallel, + world_size, dtype): assert out_features % world_size == 0 rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) if not torch.distributed.is_initialized(): @@ -111,7 +118,7 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size rank = parallel_state.get_tensor_model_parallel_rank() # set seed torch.random.manual_seed(0) - batch_size = 8 + batch_size = 2 seqlen = 512 assert batch_size * seqlen % world_size == 0 x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype, @@ -120,7 +127,10 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size # as rank 0 will have an extra bias that changes the RNG. # If we don't divide by batch_size, the gradient gets a bit too large. 
g = torch.randn_like(x_pt) / 32 - x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() + if sequence_parallel: + x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() + else: + x = x_pt.detach().clone().requires_grad_() model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype) model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device, @@ -129,7 +139,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size partition_in_features = in_features // world_size model = ParallelFusedDenseGeluDense(in_features, out_features, in_features, process_group=parallel_state.get_tensor_model_parallel_group(), - bias2=has_bias2 and rank == 0, device=device, dtype=dtype) + bias2=has_bias2 and rank == 0, + sequence_parallel=sequence_parallel, + device=device, dtype=dtype) with torch.no_grad(): model.fc1.weight.copy_( @@ -148,16 +160,21 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) partition_batch_dim = batch_size * seqlen // world_size assert torch.allclose( - out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], + out, + out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else out_pt, rtol=rtol, atol=atol ) out_pt.backward(g) - out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]) + out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else g) parallel_state.destroy_model_parallel() assert torch.allclose( - x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], + x.grad, + x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] + if sequence_parallel else x_pt.grad, rtol=rtol, atol=atol ) # The error for d_weight and d_bias is quite a bit higher
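The distributed utilities changed in this patch are meant to be driven from the training loop: `sync_shared_params` (the renamed `sync_sequence_parallel_params`) broadcasts the parameters marked `_shared_params` so all tensor-parallel ranks start identical, while `allreduce_sequence_parallel_grad` is still needed after each backward pass when sequence parallelism is on. A hedged usage sketch; the loop structure and `loss_fn` are hypothetical and not part of the patch:

```python
from flash_attn.utils.distributed import sync_shared_params, allreduce_sequence_parallel_grad

def train(model, dataloader, optimizer, process_group, sequence_parallel=True, loss_fn=None):
    # One-time sync at startup: replicated parameters (e.g. the LayerNorm weights marked
    # with `_shared_params`) are broadcast so every tensor-parallel rank holds the same values.
    sync_shared_params(model, process_group)
    for input_ids, labels in dataloader:
        logits = model(input_ids).logits
        loss = loss_fn(logits, labels)
        loss.backward()
        if sequence_parallel:
            # With sequence parallelism each rank saw a different token shard, so the
            # gradients of parameters marked `_sequence_parallel` must be all-reduced.
            allreduce_sequence_parallel_grad(model, process_group)
        optimizer.step()
        optimizer.zero_grad()
```

In the patch itself, `GPTModel`/`GPTLMHeadModel` already call `sync_shared_params` from `tie_weights` when a `process_group` is given, so a trainer typically only needs to add the gradient all-reduce step.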