[TP] Implement TensorParallel without sequence parallel
commit 93383bd55b
parent ce26d3d73d
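The hunks below thread a `sequence_parallel` flag (read from the model config, default True) through the GPT model, the parallel embedding/MHA/MLP modules, the fused-dense autograd functions, and the distributed helpers. A minimal sketch of the communication pattern the flag selects, assuming activations already flattened to (batch * seqlen, hidden) and an initialized tensor-parallel process group; the helper name here is illustrative, not part of the library:

import torch
import torch.distributed as dist

def tp_output_reduction(partial_out: torch.Tensor, group, sequence_parallel: bool) -> torch.Tensor:
    """Combine the partial outputs of a row-parallel matmul across tensor-parallel ranks (sketch)."""
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return partial_out
    if sequence_parallel:
        # Sequence parallel: each rank keeps only its (batch * seqlen) // world_size slice,
        # already summed over ranks (reduce-scatter along dim 0).
        out = torch.empty(partial_out.shape[0] // world_size, *partial_out.shape[1:],
                          dtype=partial_out.dtype, device=partial_out.device)
        dist.reduce_scatter_tensor(out, partial_out, group=group)
        return out
    # Plain tensor parallel (what this commit adds): every rank keeps the full sequence,
    # so the partial sums are combined with an all-reduce instead.
    dist.all_reduce(partial_out, group=group)
    return partial_out

With sequence_parallel=False every rank ends up holding the full activation, which is the behavior the non-sequence-parallel (inference-oriented) path below relies on.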
@@ -20,7 +20,7 @@ from flash_attn.modules.mha import MHA, ParallelMHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
-from flash_attn.utils.distributed import sync_sequence_parallel_params
+from flash_attn.utils.distributed import sync_shared_params
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import GenerationMixin
 
@@ -62,7 +62,9 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     mha_cls = MHA if process_group is None else ParallelMHA
     serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
                      if process_group is None else {})
-    parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
+    parallel_kwargs = ({'process_group': process_group,
+                        'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                       if process_group is not None else {})
     mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
                         softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                         rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
@@ -99,7 +101,9 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         if FusedDenseGeluDense is None:
             raise ImportError('fused_dense is not installed')
         mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense
-        parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
+        parallel_kwargs = ({'process_group': process_group,
+                            'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                           if process_group is not None else {})
         mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl,
                           **parallel_kwargs, **factory_kwargs)
     elif fused_dense_sqrelu_dense:
@@ -113,13 +117,15 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
 
 def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {'device': device, 'dtype': dtype}
+    sequence_parallel = getattr(config, 'sequence_parallel', True)
     mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
     mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
     norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
     block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                   prenorm=True, resid_dropout=config.resid_pdrop,
                   fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
-                  sequence_parallel=process_group is not None)
+                  sequence_parallel=sequence_parallel and process_group is not None,
+                  mark_shared_params=process_group is not None)
     block.layer_idx = layer_idx
     return block
 
@@ -180,6 +186,7 @@ class GPTModel(GPTPreTrainedModel):
         super().__init__(config)
         factory_kwargs = {'device': device, 'dtype': dtype}
         self.process_group = process_group
+        self.sequence_parallel = getattr(config, 'sequence_parallel', True)
         assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
         self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         if config.vocab_size % self.pad_vocab_size_multiple != 0:
@@ -192,7 +199,8 @@ class GPTModel(GPTPreTrainedModel):
         else:
             self.embeddings = ParallelGPT2Embeddings(
                 config.hidden_size, config.vocab_size, config.max_position_embeddings,
-                process_group=process_group, **factory_kwargs
+                process_group=process_group, sequence_parallel=self.sequence_parallel,
+                **factory_kwargs
             )
         self.emb_drop = nn.Dropout(config.embd_pdrop)
 
@@ -209,10 +217,13 @@ class GPTModel(GPTPreTrainedModel):
         # is the final layer norm.
         self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
                                  **factory_kwargs)
-        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
         if process_group is not None:
             for p in self.ln_0.parameters():
-                p._sequence_parallel = True
+                # Mark the norm parameters as "shared_params" so that we sync their values at init.
+                p._shared_params = True
+                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
+                if self.sequence_parallel:
+                    p._sequence_parallel = True
 
         self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
                                                   **factory_kwargs)
@@ -224,14 +235,14 @@ class GPTModel(GPTPreTrainedModel):
 
     def tie_weights(self):
         if self.process_group is not None:
-            sync_sequence_parallel_params(self, self.process_group)
+            sync_shared_params(self, self.process_group)
 
     def forward(self, input_ids, position_ids=None, inference_params=None):
         # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
         # dimensions so that we can split on it easily, in case of small batch size.
         # Only the attention layers need to know the seqlen.
         embedding_kwargs = ({'combine_batch_seqlen_dim': True}
-                            if self.process_group is not None else {})
+                            if self.process_group is not None and self.sequence_parallel else {})
         hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
         # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
         if not self.fused_dropout_add_ln:
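The comment above about combining the batch and seqlen dimensions is what makes the reduce-scatter in the sequence-parallel path work for small batches: the split happens along dim 0, and batch alone may not be divisible by the tensor-parallel world size while batch * seqlen usually is. A small shape check with hypothetical sizes:

from einops import rearrange
import torch

batch, seqlen, hidden, world_size = 2, 2048, 1024, 8

x = torch.randn(batch, seqlen, hidden)
# batch=2 cannot be split 8 ways, but batch * seqlen = 4096 can.
x_flat = rearrange(x, 'b s d -> (b s) d')    # (4096, 1024)
assert x_flat.shape[0] % world_size == 0
shard_rows = x_flat.shape[0] // world_size   # 512 rows per rank after reduce-scatter
print(shard_rows)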
@@ -243,7 +254,8 @@ class GPTModel(GPTPreTrainedModel):
                 self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
                 residual_in_fp32=True
             )
-        mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {})
+        mixer_kwargs = ({'seqlen': input_ids.shape[1]}
+                        if self.process_group is not None and self.sequence_parallel else {})
         if inference_params is not None:
             mixer_kwargs['inference_params'] = inference_params
         for layer in self.layers:
@@ -263,8 +275,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         else:
             if ColumnParallelLinear is None:
                 raise ImportError('fused_dense_lib is not installed')
-            self.lm_head = ColumnParallelLinear(config.n_embd, config.vocab_size, process_group,
-                                                bias=False, **factory_kwargs)
+            self.lm_head = ColumnParallelLinear(
+                config.n_embd, config.vocab_size, process_group, bias=False,
+                sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
+            )
         # Initialize weights and apply final processing
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
@@ -273,7 +287,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
         if self.process_group is not None:
-            sync_sequence_parallel_params(self, self.process_group)
+            sync_shared_params(self, self.process_group)
 
     def forward(self, input_ids, position_ids=None, inference_params=None):
         """
@@ -23,7 +23,8 @@ class Block(nn.Module):
 
     def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                  dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
-                 fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False):
+                 fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False,
+                 mark_shared_params=False):
         """
         return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
         This is for performance reason: for post-norm architecture, returning the input allows us
@@ -51,6 +52,12 @@ class Block(nn.Module):
             assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
             assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
 
+        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
+        # then the input to each worker in the tensor parallel group will be different.
+        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
+        # For now this is not an issue because we always use sequence_parallel=True during training
+        # and only use sequence_parallel=False during inference.
+
         # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
         if sequence_parallel:
             for p in self.norm1.parameters():
@@ -58,6 +65,13 @@ class Block(nn.Module):
             if hasattr(self, 'norm2'):
                 for p in self.norm2.parameters():
                     p._sequence_parallel = True
+        # Mark the norm parameters as "shared_params" so that we sync their values at init.
+        if mark_shared_params:
+            for p in self.norm1.parameters():
+                p._shared_params = True
+            if hasattr(self, 'norm2'):
+                for p in self.norm2.parameters():
+                    p._shared_params = True
 
     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
                 mixer_kwargs=None):
@@ -6,7 +6,7 @@ from torch import Tensor
 
 from einops import rearrange
 
-from flash_attn.utils.distributed import reduce_scatter
+from flash_attn.utils.distributed import reduce_scatter, all_reduce
 
 
 class GPT2Embeddings(nn.Module):
@@ -130,13 +130,14 @@ class ColumnParallelEmbedding(nn.Embedding):
 class ParallelGPT2Embeddings(nn.Module):
 
     def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
-                 padding_idx=None, device=None, dtype=None):
+                 padding_idx=None, sequence_parallel=True, device=None, dtype=None):
         """
         If max_position_embeddings <= 0, there's no position embeddings
         """
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
         self.word_embeddings = VocabParallelEmbedding(
             vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
             **factory_kwargs
@@ -167,4 +168,5 @@ class ParallelGPT2Embeddings(nn.Module):
                 embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
         if combine_batch_seqlen_dim:
             embeddings = rearrange(embeddings, 'b s d -> (b s) d')
-        return embeddings if world_size <= 1 else reduce_scatter(embeddings, self.process_group)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
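In ParallelGPT2Embeddings, each rank embeds only the tokens that fall in its vocab shard, so the partial results always have to be summed across ranks; the `reduce_fn` choice above only decides whether that sum is also scattered along the sequence (sequence parallel) or kept whole on every rank (plain tensor parallel, via the new all_reduce). A single-process illustration, using the full embedding matrix for simplicity, of why summing the per-shard partials recovers the full embedding:

import torch

vocab_size, hidden, world_size = 16, 4, 2
weight = torch.randn(vocab_size, hidden)
tokens = torch.tensor([1, 9, 3])

# Each "rank" owns a contiguous vocab shard and embeds only the tokens inside it.
shard = vocab_size // world_size
partials = []
for rank in range(world_size):
    local = torch.zeros(len(tokens), hidden)
    mask = (tokens >= rank * shard) & (tokens < (rank + 1) * shard)
    local[mask] = weight[tokens[mask]]
    partials.append(local)

# The cross-rank reduction (reduce_scatter or all_reduce in the real module) just sums these.
full = sum(partials)
assert torch.allclose(full, weight[tokens])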
@@ -497,11 +497,10 @@ class ParallelMHA(nn.Module):
 
     def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
                  softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
-                 rotary_emb_scale_base=0,
-                 use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
+                 rotary_emb_scale_base=0, use_flash_attn=False, checkpointing=False,
+                 sequence_parallel=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
         self.embed_dim = embed_dim
         self.causal = causal
         self.layer_idx = layer_idx
@@ -521,12 +520,13 @@ class ParallelMHA(nn.Module):
         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError('fused_dense is not installed')
         self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
-                                         **factory_kwargs)
+                                         sequence_parallel=sequence_parallel, **factory_kwargs)
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                          attention_dropout=dropout)
         # output projection always have the bias (for now)
-        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, **factory_kwargs)
+        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
+                                          sequence_parallel=sequence_parallel, **factory_kwargs)
 
     def forward(self, x, seqlen=None, **kwargs):
         """
@@ -15,26 +15,29 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 import fused_dense_lib as fused_dense_cuda
 
 from flash_attn.ops.gelu_activation import gelu_bwd
-from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, reduce_scatter
+from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
+from flash_attn.utils.distributed import reduce_scatter, all_reduce
 
 
 class FusedDenseFunc(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None):
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
+                sequence_parallel=True):
         """
-        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
-        we do an all_gather_raw of x before doing the matmul.
+        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
+        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
         """
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.sequence_parallel = sequence_parallel
 
         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             # We want to kick off the all_gather early, before weight dtype conversion
             total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
         else:
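In FusedDenseFunc.forward, the only thing `sequence_parallel` changes is whether the sharded input rows are gathered before the matmul. A sketch of that branch using plain torch.distributed calls (illustrative, not the library code — the real version issues the gather asynchronously so it overlaps with the weight dtype conversion and only waits right before the GEMM):

import torch
import torch.distributed as dist

def maybe_gather_input(x: torch.Tensor, process_group, sequence_parallel: bool) -> torch.Tensor:
    """With sequence parallelism each rank holds (batch * seqlen) // world_size rows of x,
    so the rows are all-gathered before the GEMM; otherwise x is already the full input."""
    if process_group is None or not sequence_parallel:
        return x
    world_size = dist.get_world_size(process_group)
    total_x = torch.empty(world_size * x.shape[0], *x.shape[1:], dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(total_x, x.contiguous(), group=process_group)
    return total_x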
@@ -44,7 +47,7 @@ class FusedDenseFunc(torch.autograd.Function):
             weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
             bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
         weight = weight.contiguous()
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
@@ -66,9 +69,10 @@ class FusedDenseFunc(torch.autograd.Function):
         grad_input, = args
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
         if ctx.compute_weight_gradient:
             x, weight = ctx.saved_tensors
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
             else:
                 total_x = x
@@ -86,13 +90,13 @@ class FusedDenseFunc(torch.autograd.Function):
                                          grad_output, weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
             if process_group is not None:
-                grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
-                                                                   async_op=True)
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
         else:
             grad_input = None
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 handle_x.wait()
             grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
@@ -102,15 +106,17 @@ class FusedDenseFunc(torch.autograd.Function):
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
         if process_group is not None and ctx.needs_input_grad[0]:
             handle_grad_input.wait()
-        return grad_input, grad_weight, grad_bias, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None
 
 
 def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
-                     return_residual: bool = False, process_group: Optional[ProcessGroup] = None):
+                     return_residual: bool = False, process_group: Optional[ProcessGroup] = None,
+                     sequence_parallel: bool = True):
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
+        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group,
+                                    sequence_parallel)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
@@ -136,7 +142,7 @@ class FusedDense(nn.Linear):
 class ColumnParallelLinear(nn.Linear):
 
     def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
-                 bias: bool = True, device=None, dtype=None) -> None:
+                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         if out_features % world_size != 0:
             raise ValueError(f'out_features ({out_features}) must be divisible by '
@@ -144,19 +150,20 @@ class ColumnParallelLinear(nn.Linear):
         super().__init__(in_features, out_features // world_size, bias=bias,
                          device=device, dtype=dtype)
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
 
     def forward(self, x):
-        """
-        We're doing Tensor Parallel with sequence parallelism: we do an all_gather of
-        x before doing the matmul.
-        """
-        return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
+        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
+        # we do an all_gather of x before doing the matmul.
+        # If not, then the input is already gathered.
+        return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group,
+                                sequence_parallel=self.sequence_parallel)
 
 
 class RowParallelLinear(nn.Linear):
 
     def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
-                 bias: bool = True, device=None, dtype=None) -> None:
+                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         rank = torch.distributed.get_rank(process_group)
         if in_features % world_size != 0:
@@ -166,6 +173,7 @@ class RowParallelLinear(nn.Linear):
         super().__init__(in_features // world_size, out_features, bias=bias and rank == 0,
                          device=device, dtype=dtype)
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
 
     def forward(self, x):
         """
@@ -173,7 +181,8 @@ class RowParallelLinear(nn.Linear):
        a reduce_scatter of the result.
        """
        out = fused_dense_func(x, self.weight, self.bias)
-        return reduce_scatter(out, self.process_group)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return reduce_fn(out, self.process_group)
 
 
 class FusedDenseGeluDenseFunc(torch.autograd.Function):
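Together, the two layers reproduce the Megatron-style pairing: ColumnParallelLinear gathers (or not) on the way in, RowParallelLinear reduce-scatters (or all-reduces) on the way out, and the intermediate features stay sharded either way. A shape-level sketch of one Column -> Row round trip per rank, with hypothetical sizes:

def tp_linear_shapes(batch_tokens: int, d: int, world_size: int, sequence_parallel: bool):
    """Shapes seen by one rank for a ColumnParallelLinear -> RowParallelLinear pair (illustration only)."""
    if sequence_parallel:
        x_in = (batch_tokens // world_size, d)      # sharded rows enter ColumnParallelLinear
        gathered = (batch_tokens, d)                # all_gather before the first GEMM
        out = (batch_tokens // world_size, d)       # reduce_scatter after the second GEMM
    else:
        x_in = (batch_tokens, d)                    # full input on every rank
        gathered = x_in                             # no gather needed
        out = (batch_tokens, d)                     # all_reduce keeps the full output everywhere
    hidden_local = (batch_tokens, 4 * d // world_size)  # intermediate is always sharded on features
    return x_in, gathered, hidden_local, out

print(tp_linear_shapes(batch_tokens=4096, d=1024, world_size=8, sequence_parallel=False))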
@@ -181,10 +190,11 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
-                checkpoint_lvl=0, heuristic=0, process_group=None):
+                checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True):
         """
-        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
-        we do an all_gather of x before doing the matmul.
+        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
+        with sequence parallelism: we do an all_gather of x before doing the matmul.
+        If sequence_parallel=False, then the input is already gathered.
 
         checkpoint_lvl:
             0: no recomputation in the bwd
@@ -197,13 +207,14 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         assert checkpoint_lvl in [0, 1, 2]
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.sequence_parallel = sequence_parallel
         ctx.checkpoint_lvl = checkpoint_lvl
         ctx.heuristic = heuristic
 
         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             # We want to kick off the all_gather early, before weight dtype conversion
             total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
         else:
@@ -218,7 +229,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         bias1 = bias1.contiguous() if bias1 is not None else None
         weight2 = weight2.contiguous()
         bias2 = bias2.contiguous() if bias2 is not None else None
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
@@ -257,13 +268,14 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         grad_input, = args
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
         x, weight1, weight2, *rest = ctx.saved_tensors
-        if process_group is None:
+        if process_group is None or not sequence_parallel:
             total_x = x
         batch_shape = grad_output.shape[:-1]
         batch_dim = batch_shape.numel()
         if checkpoint_lvl in [0, 1]:
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
             if checkpoint_lvl == 0:
                 gelu_in, output1 = rest
@@ -272,7 +284,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
                 output1 = F.gelu(gelu_in, approximate='tanh')
         elif checkpoint_lvl == 2:
             bias1, = rest
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 total_x, _ = all_gather_raw(x, process_group)
             if ctx.heuristic == -1:
                 gelu_in = F.linear(total_x, weight1, bias1)
@@ -314,13 +326,13 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
                                          grad_gelu, weight1)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
             if process_group is not None:
-                grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
-                                                                   async_op=True)
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
         else:
             grad_input = None
         if ctx.heuristic == -1:
             if ctx.needs_input_grad[1]:
-                if process_group is not None:
+                if process_group is not None and sequence_parallel:
                     handle_x.wait()
                 grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                     total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu,
@@ -331,7 +343,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
                 grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
         else:
             if ctx.needs_input_grad[1]:
-                if process_group is not None:
+                if process_group is not None and sequence_parallel:
                     handle_x.wait()
                 grad_weight1 = F.linear(grad_gelu.t(),
                                         total_x.reshape(batch_dim, total_x.shape[-1]).t())
@@ -340,7 +352,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         if process_group is not None and ctx.needs_input_grad[0]:
             handle_grad_input.wait()
         return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
-                None, None, None, None, None)
+                None, None, None, None, None, None)
 
 
 def fused_dense_gelu_dense_func(
@@ -348,15 +360,16 @@ def fused_dense_gelu_dense_func(
     bias2: Optional[Tensor] = None,
     save_pre_act: bool = True, return_residual: bool = False,
     checkpoint_lvl: int = 0, heuristic: int = 0,
-    process_group: Optional[ProcessGroup] = None
+    process_group: Optional[ProcessGroup] = None,
+    sequence_parallel: bool = True
 ):
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
             and (bias2 is None or bias2.is_cuda) and dtype_eligible):
         return FusedDenseGeluDenseFunc.apply(
-            x, weight1, bias1, weight2, bias2,
-            save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group
+            x, weight1, bias1, weight2, bias2, save_pre_act, return_residual,
+            checkpoint_lvl, heuristic, process_group, sequence_parallel
         )
     else:
         assert process_group is None
@@ -418,7 +431,7 @@ class ParallelFusedDenseGeluDense(nn.Module):
 
     def __init__(self, in_features, hidden_features, out_features=None,
                  process_group: ProcessGroup = None, bias1=True, bias2=True,
-                 checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
+                 sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
         """
         process_group is required. We're doing Tensor Parallel with sequence parallelism:
         we do an all_gather of x before doing the matmul, gelu, then matmul.
@@ -441,6 +454,7 @@ class ParallelFusedDenseGeluDense(nn.Module):
         if out_features is None:
             out_features = in_features
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
         self.checkpoint_lvl = checkpoint_lvl
         self.heuristic = heuristic
         self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
@@ -452,6 +466,9 @@ class ParallelFusedDenseGeluDense(nn.Module):
         out = fused_dense_gelu_dense_func(
             x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
             save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
-            heuristic=self.heuristic, process_group=self.process_group
+            heuristic=self.heuristic,
+            process_group=self.process_group,
+            sequence_parallel=self.sequence_parallel
         )
-        return reduce_scatter(out, self.process_group)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return reduce_fn(out, self.process_group)
@@ -14,7 +14,7 @@ if "reduce_scatter_tensor" not in dir(torch.distributed):
     torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
 
 
-# Raw operation, oes does support autograd, but does support async
+# Raw operation, does not support autograd, but does support async
 def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
     world_size = torch.distributed.get_world_size(process_group)
     output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
@@ -24,7 +24,7 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
     return output, handle
 
 
-# Raw operation, oes does support autograd, but does support async
+# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
@@ -36,6 +36,13 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bo
     return output, handle
 
 
+# Raw operation, does not support autograd, but does support async
+def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    input_ = input_.contiguous()
+    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
+    return input_, handle
+
+
 class AllGatherFunc(torch.autograd.Function):
     """Gather the input from sequence parallel region and concatenate."""
 
@@ -74,12 +81,30 @@ class ReduceScatterFunc(torch.autograd.Function):
 reduce_scatter = ReduceScatterFunc.apply
 
 
-def sync_sequence_parallel_params(model: torch.nn.Module, process_group: ProcessGroup):
-    # We want to iterate over parameters with _sequence_parallel=True in the same order,
+class AllReduceFunc(torch.autograd.Function):
+    """Gather the input from sequence parallel region and concatenate."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = all_reduce_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        return grad_output, None
+
+
+# Supports autograd, but does not support async
+all_reduce = AllReduceFunc.apply
+
+
+def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
+    # We want to iterate over parameters with _shared_params=True in the same order,
     # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
-    params_seqparallel = {name: p for name, p in model.named_parameters()
-                          if getattr(p, '_sequence_parallel', False)}
-    for _, p in sorted(params_seqparallel.items()):
+    pamams_shared = {name: p for name, p in model.named_parameters()
+                     if getattr(p, '_shared_params', False)}
+    for _, p in sorted(pamams_shared.items()):
         with torch.no_grad():
             # Broadcast needs src to be global rank, not group rank
             torch.distributed.broadcast(
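AllReduceFunc is the autograd-visible counterpart of all_reduce_raw: the forward sums the partial tensors across ranks and the backward passes the incoming gradient through unchanged, since each rank's contribution enters the sum with coefficient 1. A single-process analogue of that backward rule (no process group needed), treating the other ranks' contributions as constants:

import torch

# All-reduce computes y = sum over ranks of x_r, on every rank. For one rank's input x_r,
# dy/dx_r is the identity, so AllReduceFunc.backward can simply return grad_output.
x = torch.randn(4, requires_grad=True)
partials = [x, x.detach(), x.detach()]   # other "ranks'" contributions, constant w.r.t. x
y = sum(partials)
y.backward(torch.ones_like(y))
assert torch.equal(x.grad, torch.ones_like(x))  # gradient flows through with coefficient 1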
@@ -94,8 +119,9 @@ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: Proc
     params_seqparallel = {name: p for name, p in model.named_parameters()
                           if getattr(p, '_sequence_parallel', False)}
     grads = [p.grad for _, p in sorted(params_seqparallel.items())]
-    with torch.no_grad():
-        coalesced = torch._utils._flatten_dense_tensors(grads)
-        torch.distributed.all_reduce(coalesced, group=process_group)
-        for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
-            buf.copy_(synced)
+    if grads:
+        with torch.no_grad():
+            coalesced = torch._utils._flatten_dense_tensors(grads)
+            torch.distributed.all_reduce(coalesced, group=process_group)
+            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
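With the norm parameters now tagged _shared_params and (only when sequence parallel is on) _sequence_parallel, a training step is expected to use these helpers around the usual loop: sync_shared_params once after building the model, and allreduce_sequence_parallel_grad after each backward pass. A hedged sketch of where the calls would sit; the model, batch, optimizer, and loss are placeholders:

import torch
from flash_attn.utils.distributed import sync_shared_params, allreduce_sequence_parallel_grad

def train_step(model, batch, optimizer, process_group):
    # Sketch only: assumes the model was built with this tensor-parallel process_group
    # and that sync_shared_params(model, process_group) was called once after init.
    loss = model(batch['input_ids']).logits.float().mean()   # placeholder loss
    loss.backward()
    # Grads of params marked _sequence_parallel (e.g. LayerNorm weights) were computed from
    # this rank's sequence shard only, so sum them across the tensor-parallel group.
    allreduce_sequence_parallel_grad(model, process_group)
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach()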
@@ -23,10 +23,12 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_pos_emb', [True, False])
 # @pytest.mark.parametrize('has_pos_emb', [True])
 @pytest.mark.parametrize('dim', [1024])
-def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
+def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
     head_dim = 64
     assert dim % head_dim == 0
     num_heads = dim // head_dim
@@ -59,7 +61,8 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
                        scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
                        fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True,
                        rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
-                       pad_vocab_size_multiple=8 * world_size)
+                       pad_vocab_size_multiple=8 * world_size,
+                       sequence_parallel=sequence_parallel)
     model_pt = GPTLMHeadModel(config, device=device)
 
     def init_layer_norm(module):
@@ -75,16 +78,15 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
     torch.distributed.all_gather_into_tensor(
         sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
     )
-    sequence_parallel_nparams = sum(p.numel() for p in model.parameters()
-                                    if getattr(p, '_sequence_parallel', False))
-    sequence_parallel_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
+    shared_nparams = sum(p.numel() for p in model.parameters()
+                         if getattr(p, '_shared_params', False))
+    shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
     torch.distributed.all_gather_into_tensor(
-        sequence_parallel_nparams_all, torch.tensor([sequence_parallel_nparams], device=device),
-        group=process_group
+        shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
     )
-    assert torch.all(sequence_parallel_nparams_all == sequence_parallel_nparams)
-    assert total_nparams == ((sharded_nparams_all - sequence_parallel_nparams_all).sum().item()
-                             + sequence_parallel_nparams)
+    assert torch.all(shared_nparams_all == shared_nparams)
+    assert total_nparams == ((sharded_nparams_all - shared_nparams_all).sum().item()
+                             + shared_nparams)
 
     # vocab_size has been rounded up here
     partition_vocab_size = config.vocab_size // world_size
@@ -96,6 +98,8 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
 
     with torch.autocast(device_type='cuda', dtype=dtype):
         out = model(input_ids[:, :-1]).logits
+        if not sequence_parallel:
+            out = rearrange(out, 'b s d -> (b s) d')
     out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, 'b s d -> (b s) d')
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
@@ -23,11 +23,13 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 
 @pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('dim', [1024])
-def test_block_parallel(dim, world_size, dtype):
+def test_block_parallel(dim, sequence_parallel, world_size, dtype):
     head_dim = 64
     assert dim % head_dim == 0
     num_heads = dim // head_dim
@@ -41,7 +43,7 @@ def test_block_parallel(dim, world_size, dtype):
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 1024
     assert (batch_size * seqlen) % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype,
@@ -51,8 +53,12 @@ def test_block_parallel(dim, world_size, dtype):
     # as rank 0 will have an extra bias that changes the RNG.
     # If we don't divide by batch_size, the gradient gets a bit too large.
     g = torch.randn_like(x_pt) / 32
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
-    residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+        residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
+        residual = residual_pt.detach().clone().requires_grad_()
 
     mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2),
                            use_flash_attn=True, device=device, dtype=dtype)
@@ -69,12 +75,12 @@ def test_block_parallel(dim, world_size, dtype):
     mixer_cls = partial(ParallelMHA, num_heads=num_heads,
                         process_group=parallel_state.get_tensor_model_parallel_group(),
                         rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
-                        device=device, dtype=dtype)
+                        sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     mlp_cls = partial(ParallelFusedDenseGeluDense, hidden_features=4 * dim,
                       process_group=parallel_state.get_tensor_model_parallel_group(),
-                      device=device, dtype=dtype)
+                      sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True,
-                  sequence_parallel=True)
+                  sequence_parallel=sequence_parallel, mark_shared_params=True)
 
     partition_dim = dim // world_size
     partition_hidden_dim = 4 * dim // world_size
@@ -115,25 +121,34 @@ def test_block_parallel(dim, world_size, dtype):
     out_pt, out_residual_pt = [rearrange(x, 'b s d -> (b s) d') for x in [out_pt, out_residual_pt]]
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out,
+        out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
     assert torch.allclose(
-        out_residual, out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out_residual,
+        out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else out_residual_pt,
         rtol=rtol, atol=atol
     )
 
-    out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    (out_pt + 2 * out_residual_pt).backward(g)
+    (out + 2 * out_residual).backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+                                      if sequence_parallel else g)
     allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
     parallel_state.destroy_model_parallel()
 
     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
-        rtol=rtol, atol=atol
+        x.grad,
+        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else x_pt.grad,
+        rtol=rtol, atol=atol / 100  # magnitude of x.grad is quite small
     )
     assert torch.allclose(
-        residual.grad, residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        residual.grad,
+        residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else residual_pt.grad,
         rtol=rtol, atol=atol
     )
     # The error for d_weight and d_bias is quite a bit higher
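The tests above (and the embedding/MHA/fused-dense tests that follow) all apply the same comparison rule once sequence_parallel is parametrized: with sequence parallelism each rank holds a (batch * seqlen) // world_size slice of the single-GPU reference activations and gradients, and without it every rank should match the full reference tensor. A tiny helper capturing that rule (illustrative; the tests inline it):

import torch

def expected_shard(ref: torch.Tensor, rank: int, world_size: int, sequence_parallel: bool) -> torch.Tensor:
    """Slice of the single-GPU reference that one tensor-parallel rank should reproduce."""
    if not sequence_parallel:
        return ref                      # full tensor is replicated on every rank
    partition = ref.shape[0] // world_size
    return ref[rank * partition:(rank + 1) * partition]

# e.g. assert torch.allclose(out, expected_shard(out_pt, rank, world_size, sequence_parallel),
#                            rtol=rtol, atol=atol)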
@@ -19,10 +19,12 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_pos_emb', [True, False])
 # @pytest.mark.parametrize('has_pos_emb', [True])
 @pytest.mark.parametrize('dim', [1024])
-def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
+def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
     vocab_size = 50264
     seqlen = 2048
     assert vocab_size % world_size == 0
@@ -46,7 +48,7 @@ def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
                               device=device, dtype=dtype)
     model = ParallelGPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
                                    parallel_state.get_tensor_model_parallel_group(),
-                                   device=device, dtype=dtype)
+                                   sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     partition_vocab_size = vocab_size // world_size
     partition_dim = dim // world_size
     with torch.no_grad():
@@ -62,13 +64,16 @@ def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
     out_pt = rearrange(model_pt(input_ids), 'b s d -> (b s) d')
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out,
+        out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
 
     g = torch.randn_like(out_pt)
     out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+                 if sequence_parallel else g)
     parallel_state.destroy_model_parallel()
 
     assert torch.allclose(
@@ -21,11 +21,13 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('head_dim', [64, 128])
 # @pytest.mark.parametrize('head_dim', [64])
 @pytest.mark.parametrize('embed_dim', [1024, 4096])
 # @pytest.mark.parametrize('embed_dim', [1024])
-def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
+def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
     assert embed_dim % head_dim == 0
     num_heads = embed_dim // head_dim
     assert num_heads % world_size == 0
@@ -38,7 +40,7 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 1024
     assert (batch_size * seqlen) % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=dtype,
@@ -47,14 +49,17 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
     # as rank 0 will have an extra bias that changes the RNG.
     # If we don't divide by batch_size, the gradient gets a bit too large.
     g = torch.randn_like(x_pt) / 32
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
 
     model_pt = MHA(embed_dim, num_heads, rotary_emb_dim=int(head_dim // 2),
                    use_flash_attn=True, device=device, dtype=dtype)
     partition_dim = embed_dim // world_size
     model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(),
                         rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
-                        device=device, dtype=dtype)
+                        sequence_parallel=sequence_parallel, device=device, dtype=dtype)
 
     with torch.no_grad():
         model.Wqkv.weight.copy_(
@@ -75,17 +80,22 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
     out_pt = rearrange(model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen)), 'b s d -> (b s) d')
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out,
+        out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
 
     out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+                 if sequence_parallel else g)
     parallel_state.destroy_model_parallel()
 
     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
-        rtol=rtol, atol=atol
+        x.grad,
+        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else x_pt.grad,
+        rtol=rtol, atol=atol / 100  # magnitude of x.grad is quite small
     )
     # The error for d_weight and d_bias is quite a bit higher
     assert torch.allclose(
@@ -19,14 +19,15 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
-# @pytest.mark.parametrize('world_size', [8])
+# @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_bias', [True, False])
-# @pytest.mark.parametrize('has_bias', [True])
-@pytest.mark.parametrize('out_features', [1024, 4096])
-# @pytest.mark.parametrize('out_features', [1024])
-@pytest.mark.parametrize('in_features', [1024, 4096])
-# @pytest.mark.parametrize('in_features', [4096])
-def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtype):
+# @pytest.mark.parametrize('has_bias', [False])
+@pytest.mark.parametrize('out_features', [1024])
+@pytest.mark.parametrize('in_features', [4096])
+def test_fused_linear_bias(in_features, out_features, has_bias, sequence_parallel,
+                           world_size, dtype):
     assert out_features % world_size == 0
     rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
     if not torch.distributed.is_initialized():
@@ -37,18 +38,21 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 512
     assert batch_size * seqlen % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
                        requires_grad=True)
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
 
     model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
     partition_out_features = out_features // world_size
     model = ColumnParallelLinear(in_features, out_features,
                                  parallel_state.get_tensor_model_parallel_group(), bias=has_bias,
-                                 device=device, dtype=dtype)
+                                 sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     with torch.no_grad():
         model.weight.copy_(
             model_pt.weight[rank * partition_out_features:(rank + 1) * partition_out_features]
@@ -73,7 +77,9 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
 
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        x.grad,
+        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else x_pt.grad,
         rtol=rtol, atol=atol
     )
     # The error for d_weight and d_bias is quite a bit higher
@@ -94,13 +100,14 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_bias2', [True, False])
 # @pytest.mark.parametrize('has_bias2', [True])
-@pytest.mark.parametrize('out_features', [1024, 4096])
-# @pytest.mark.parametrize('out_features', [1024])
-@pytest.mark.parametrize('in_features', [1024, 4096])
-# @pytest.mark.parametrize('in_features', [1024])
-def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size, dtype):
+@pytest.mark.parametrize('out_features', [4096])
+@pytest.mark.parametrize('in_features', [1024])
+def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_parallel,
+                                world_size, dtype):
     assert out_features % world_size == 0
     rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
     if not torch.distributed.is_initialized():
@@ -111,7 +118,7 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 512
     assert batch_size * seqlen % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
@@ -120,7 +127,10 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     # as rank 0 will have an extra bias that changes the RNG.
     # If we don't divide by batch_size, the gradient gets a bit too large.
     g = torch.randn_like(x_pt) / 32
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
 
     model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
     model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
@@ -129,7 +139,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     partition_in_features = in_features // world_size
     model = ParallelFusedDenseGeluDense(in_features, out_features, in_features,
                                         process_group=parallel_state.get_tensor_model_parallel_group(),
-                                        bias2=has_bias2 and rank == 0, device=device, dtype=dtype)
+                                        bias2=has_bias2 and rank == 0,
+                                        sequence_parallel=sequence_parallel,
+                                        device=device, dtype=dtype)
 
     with torch.no_grad():
         model.fc1.weight.copy_(
@@ -148,16 +160,21 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out,
+        out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
 
     out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+                 if sequence_parallel else g)
     parallel_state.destroy_model_parallel()
 
     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        x.grad,
+        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else x_pt.grad,
         rtol=rtol, atol=atol
     )
     # The error for d_weight and d_bias is quite a bit higher