split/merge tp and dp into different files

This commit is contained in:
ferdinand.mom 2024-11-04 16:26:11 +00:00
parent db926026a6
commit 77e85fe490
7 changed files with 379 additions and 385 deletions

View File

@@ -1,10 +1,13 @@
import contextlib
import torch
import torch.distributed as dist
import contextlib
from torch import nn
from torch.autograd import Variable
from picotron.data_parallel.bucket import BucketManager
import picotron.process_group_manager as pgm
class DataParallel(nn.Module):
class DataParallelNaive(nn.Module):
def __init__(self, module):
"""
Initializes the DataParallel wrapper for a given module.
@@ -49,4 +52,115 @@ class DataParallel(nn.Module):
"""
self.require_backward_grad_sync = False
yield
self.require_backward_grad_sync = True
self.require_backward_grad_sync = True
class DataParallelBucket(nn.Module):
def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):
"""
Initialize the DataParallelBucket module.
Args:
module (nn.Module): The model to be parallelized.
bucket_cap_mb (int, optional): The maximum size of each gradient synchronization bucket in megabytes.
Defaults to 25 MB.
grad_type (torch.dtype, optional): The data type of gradients, defaulting to float32.
Note: the process group used for gradient synchronization is taken from the global process group
manager (cp_dp_group), so it covers both the data parallel and context parallel dimensions.
"""
super().__init__()
self.module = module
self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket
self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.cp_dp_group, bucket_size, grad_type)
self.register_backward_hook()
self._post_backward_callback_set = False # whether the callback for wait gradient synchronization is set
def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)
def backward(self, input_tensor, output_tensor, output_tensor_grad):
return self.module.backward(input_tensor, output_tensor, output_tensor_grad)
def get_flops(self, *args, **kwargs):
return self.module.get_flops(*args, **kwargs)
def register_backward_hook(self):
"""
Registers a backward hook to manually accumulate and synchronize gradients.
This hook serves two main purposes:
1. PyTorch does not natively support gradient accumulation with mixed precision.
2. After gradient accumulation, it flags parameters as ready for synchronization.
The gradient accumulation functions are stored to prevent them from going out of scope.
References:
- https://github.com/NVIDIA/Megatron-LM/issues/690
- https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html
- https://arxiv.org/abs/2006.15704 (page 5)
"""
self.grad_accs = []
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
self.grad_accs.append(grad_acc_fn)
def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
"""
Creates a hook for each parameter to handle gradient accumulation and synchronization.
"""
def param_hook(*unused):
"""
The hook called after the gradient is ready. It performs the following:
1. Accumulates the gradient into the main gradient.
2. Adds a post-backward callback to wait for gradient synchronization completion.
3. Marks the parameter as ready for synchronization.
"""
if param.requires_grad:
assert param.grad is not None
param.main_grad.add_(param.grad.data) # accumulate the gradients
param.grad = None
# skip the gradient synchronization (gradient accumulation/PP micro batches)
if self.require_backward_grad_sync:
# Add a callback to wait for gradient synchronization. Ensures the callback is added only once.
# Callback is executed after the backward pass. It should be added per backward pass.
if not self._post_backward_callback_set:
Variable._execution_engine.queue_callback(self._post_backward)
self._post_backward_callback_set = True
# mark the parameter as ready for gradient synchronization.
bucket_manager.mark_param_as_ready(param)
return param_hook
@contextlib.contextmanager
def no_sync(self):
"""A context manager to disable gradient synchronization."""
self.require_backward_grad_sync = False
yield
self.require_backward_grad_sync = True
def _post_backward(self):
"""
A post-backward callback that waits for gradient synchronization to finish, then copies
the synchronized gradients back to the parameters' grad attribute.
This method is called after the backward pass and before the optimizer step.
"""
self.bucket_manager.wait()
self._post_backward_callback_set = False
# copy to params.grad so we can use the optimizer to update the parameters
for p in self.module.parameters():
if p.requires_grad:
p.grad = p.main_grad.to(p.dtype) # In PyTorch, you cannot assign a gradient with one data type to a tensor of another data type.
def reset(self):
"""
Reset the bucket manager and zero out gradients in the model
"""
self.bucket_manager.reset()
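The expand_as/next_functions trick used in register_backward_hook above is easy to miss. Below is a minimal, self-contained sketch (not part of the commit) showing how a hook attached to a parameter's AccumulateGrad node fires once that parameter's gradient is ready:

import torch
from torch import nn

# param.expand_as(param) returns a view whose grad_fn is an Expand node;
# its next_functions[0][0] is the AccumulateGrad node of the parameter itself.
layer = nn.Linear(4, 2)
grad_accs = []  # keep references so the hooks are not garbage collected
for name, p in layer.named_parameters():
    acc = p.expand_as(p).grad_fn.next_functions[0][0]  # AccumulateGrad node
    acc.register_hook(lambda *unused, n=name: print(f"gradient ready for {n}"))
    grad_accs.append(acc)

layer(torch.randn(3, 4)).sum().backward()  # the hook fires once per parameter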

View File

@@ -1,118 +0,0 @@
import torch
import contextlib
from torch import nn
from torch.autograd import Variable
from picotron.data_parallel.bucket import BucketManager
import picotron.process_group_manager as pgm
class DataParallel(nn.Module):
def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):
"""
Initialize the DataParallel module.
Args:
module (nn.Module): The model to be parallelized.
bucket_cap_mb (int, optional): The maximum size of each gradient synchronization bucket in megabytes.
Defaults to 25 MB.
grad_type (torch.dtype, optional): The data type of gradients, defaulting to float32.
Note: the process group used for gradient synchronization is taken from the global process group
manager (cp_dp_group), so it covers both the data parallel and context parallel dimensions.
"""
super().__init__()
self.module = module
self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket
self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.cp_dp_group, bucket_size, grad_type)
self.register_backward_hook()
self._post_backward_callback_set = False # whether the callback for wait gradient synchronization is set
def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)
def backward(self, input_tensor, output_tensor, output_tensor_grad):
return self.module.backward(input_tensor, output_tensor, output_tensor_grad)
def get_flops(self, *args, **kwargs):
return self.module.get_flops(*args, **kwargs)
def register_backward_hook(self):
"""
Registers a backward hook to manually accumulate and synchronize gradients.
This hook serves two main purposes:
1. PyTorch does not natively support gradient accumulation with mixed precision.
2. After gradient accumulation, it flags parameters as ready for synchronization.
The gradient accumulation functions are stored to prevent them from going out of scope.
References:
- https://github.com/NVIDIA/Megatron-LM/issues/690
- https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html
- https://arxiv.org/abs/2006.15704 (page 5)
"""
self.grad_accs = []
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
self.grad_accs.append(grad_acc_fn)
def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
"""
Creates a hook for each parameter to handle gradient accumulation and synchronization.
"""
def param_hook(*unused):
"""
The hook called after the gradient is ready. It performs the following:
1. Accumulates the gradient into the main gradient.
2. Adds a post-backward callback to wait for gradient synchronization completion.
3. Marks the parameter as ready for synchronization.
"""
if param.requires_grad:
assert param.grad is not None
param.main_grad.add_(param.grad.data) # accumulate the gradients
param.grad = None
# skip the gradient synchronization (gradient accumulation/PP micro batches)
if self.require_backward_grad_sync:
# Add a callback to wait for gradient synchronization. Ensures the callback is added only once.
# Callback is executed after the backward pass. It should be added per backward pass.
if not self._post_backward_callback_set:
Variable._execution_engine.queue_callback(self._post_backward)
self._post_backward_callback_set = True
# mark the parameter as ready for gradient synchronization.
bucket_manager.mark_param_as_ready(param)
return param_hook
@contextlib.contextmanager
def no_sync(self):
"""A context manager to disable gradient synchronization."""
self.require_backward_grad_sync = False
yield
self.require_backward_grad_sync = True
def _post_backward(self):
"""
A post-backward callback that waits for gradient synchronization to finish, then copies
the synchronized gradients back to the parameters' grad attribute.
This method is called after the backward pass and before the optimizer step.
"""
self.bucket_manager.wait()
self._post_backward_callback_set = False
# copy to params.grad so we can use the optimizer to update the parameters
for p in self.module.parameters():
if p.requires_grad:
p.grad = p.main_grad.to(p.dtype) # In PyTorch, you cannot assign a gradient with one data type to a tensor of another data type.
def reset(self):
"""
Reset the bucket manager and zero out gradients in the model
"""
self.bucket_manager.reset()

View File

@@ -1,256 +0,0 @@
"""
Inspired by FairScale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from picotron.tensor_parallel.utils import VocabUtility
import torch
import math
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from typing import Callable, Optional
import picotron.process_group_manager as pgm
from picotron.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
def initialize_weight_tensor(weight, vocab_embedding=False):
"""
Initialize the weight tensor with the default initialization method in PyTorch
If not a vocab embedding, it uses U(-sqrt(k), sqrt(k)) with k = 1/in_features.
If it's a vocab embedding, it uses a normal distribution N(0, 1).
"""
if not vocab_embedding:
# Get the in_features from the shape of the weight tensor
_, in_features = weight.shape
# Calculate k and the uniform bounds
k = 1 / in_features
bound = math.sqrt(k)
# Initialize weights with U(-sqrt(k), sqrt(k))
torch.nn.init.uniform_(weight, -bound, bound)
else:
# Initialize Vocab embedding with N(0, 1)
torch.nn.init.normal_(weight, mean=0.0, std=1.0)
def _initialize_affine_weight(
weight: torch.Tensor,
out_features: int,
in_features: int,
per_partition_size: int,
partition_dim: int,
init_method: Callable[[torch.Tensor], torch.Tensor]
) -> Optional[torch.Tensor]:
"""
Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
Args:
weight: The weight tensor that will be initialized for the current partition.
out_features: second dimension of weight matrix W.
in_features: first dimension of weight matrix W.
per_partition_size: The size of the weight partition assigned to each process.
partition_dim: The dimension along which the weight matrix is split for parallelism.
init_method: The method used to initialize the weight values.
"""
# If we only use 1 process for model parallelism, we can simply initialize the weight
if pgm.process_group_manager.tp_world_size == 1:
init_method(weight)
return None
# Initialize master weight
master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
init_method(master_weight)
# Split the model into size of per_partition_size and take the corresponding partition
weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()
return None
class ColumnParallelLinear(torch.nn.Module):
"""Column Parallel Linear layer
Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
This module returns the result Y_i = XW_i + b_i in the forward method; Y_i is parallelized along the second dimension.
Arguments:
in_features: first dimension of weight matrix W.
out_features: second dimension of weight matrix W.
bias: If true, add bias
init_method: method to initialize weights
gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
gather_output: bool = False,
) -> None:
super(ColumnParallelLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
self.gather_output = gather_output
# Allocate space for the weight and bias
# Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
if bias:
self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
# Initialize weight.
_initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.output_size_per_partition,
partition_dim = 0,
init_method = init_method,
)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
input_parallel = copy_to_model_parallel_region(input_)
output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
if self.gather_output:
output = gather_from_model_parallel_region(output)
return output
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
- -
| W_1 |
| . |
W = | . | X = [X_1, ..., X_p]
| . |
| W_p |
- -
We assume that X is already parallelized. This is the case after ColumnParallelLinear.
This module returns the result Y = sum(X_i * W_i) + b in the forward method (the bias is added once, after the reduction).
Arguments:
in_features: first dimension of matrix W.
out_features: second dimension of matrix W.
bias: If true, add bias
init_method: method to initialize weights.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
):
super(RowParallelLinear, self).__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size
self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
if bias:
self.bias = Parameter(torch.Tensor(self.out_features))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
# Initialize weight.
_initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.input_size_per_partition,
partition_dim = 1,
init_method = init_method,
)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
output_parallel = F.linear(input_, self.weight) # X_i * W_i^T (the bias, if any, is added after the all-reduce)
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel)
if self.bias is not None:
output = output_ + self.bias
else:
output = output_
return output
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
) -> None:
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
self._weight = None
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
# Allocate weights.
self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
# And initialize.
_initialize_affine_weight(
self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
"""
Performs an embedding lookup for input tokens in the parallelized embedding layer
1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
"""
# Build the mask for out-of-vocabulary tokens.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
# Get the embeddings for the valid tokens.
output_parallel = F.embedding(
masked_input,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
# Embedding of out-of-vocabulary tokens is set to 0.
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs to get the final output.
output = reduce_from_model_parallel_region(output_parallel)
return output

View File

@@ -1,10 +1,21 @@
from functools import partial
from picotron.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
"""
Inspired by FairScale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from picotron.tensor_parallel.tp_utils import VocabUtility
import torch
import math
import torch.nn.init as init
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from typing import Callable, Optional
import picotron.process_group_manager as pgm
from functools import partial
import torch.nn.init as init
from picotron.tensor_parallel.tp_communications import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
class TensorParallel():
def __init__(self, model, init_method = initialize_weight_tensor):
def __init__(self, model, init_method):
super().__init__()
module_linear_name_stype_mapping_list = [
@@ -49,4 +60,247 @@ class TensorParallel():
embedding_dim=linear_layer.embedding_dim,
init_method=partial(self.init_method, vocab_embedding=True)
)
setattr(module, linear_proj_name, new_linear_layer)
setattr(module, linear_proj_name, new_linear_layer)
def initialize_weight_tensor(weight, vocab_embedding=False):
"""
Initialize the weight tensor with the default initialization method in PyTorch
If not a vocab embedding, it uses U(-sqrt(k), sqrt(k)) with k = 1/in_features.
If it's a vocab embedding, it uses a normal distribution N(0, 1).
"""
if not vocab_embedding:
# Get the in_features from the shape of the weight tensor
_, in_features = weight.shape
# Calculate k and the uniform bounds
k = 1 / in_features
bound = math.sqrt(k)
# Initialize weights with U(-sqrt(k), sqrt(k))
torch.nn.init.uniform_(weight, -bound, bound)
else:
# Initialize Vocab embedding with N(0, 1)
torch.nn.init.normal_(weight, mean=0.0, std=1.0)
def _initialize_affine_weight(
weight: torch.Tensor,
out_features: int,
in_features: int,
per_partition_size: int,
partition_dim: int,
init_method: Callable[[torch.Tensor], torch.Tensor]
) -> Optional[torch.Tensor]:
"""
Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
Args:
weight: The weight tensor that will be initialized for the current partition.
out_features: second dimension of weight matrix W.
in_features: first dimension of weight matrix W.
per_partition_size: The size of the weight partition assigned to each process.
partition_dim: The dimension along which the weight matrix is split for parallelism.
init_method: The method used to initialize the weight values.
"""
# If we only use 1 process for model parallelism, we can simply initialize the weight
if pgm.process_group_manager.tp_world_size == 1:
init_method(weight)
return None
# Initialize master weight
master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
init_method(master_weight)
# Split the model into size of per_partition_size and take the corresponding partition
weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()
return None
class ColumnParallelLinear(torch.nn.Module):
"""Column Parallel Linear layer
Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
This module returns the result Y_i = XW_i + b_i in the forward method; Y_i is parallelized along the second dimension.
Arguments:
in_features: first dimension of weight matrix W.
out_features: second dimension of weight matrix W.
bias: If true, add bias
init_method: method to initialize weights
gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
gather_output: bool = False,
) -> None:
super(ColumnParallelLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
self.gather_output = gather_output
# Allocate space for the weight and bias
# Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
if bias:
self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
# Initialize weight.
_initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.output_size_per_partition,
partition_dim = 0,
init_method = init_method,
)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
input_parallel = copy_to_model_parallel_region(input_)
output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
if self.gather_output:
output = gather_from_model_parallel_region(output)
return output
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
- -
| W_1 |
| . |
W = | . | X = [X_1, ..., X_p]
| . |
| W_p |
- -
We assume that X is already parallelized. This is the case after ColumnParallelLinear.
This module returns the result Y = sum(X_i * W_i) + b in the forward method (the bias is added once, after the reduction).
Arguments:
in_features: first dimension of matrix W.
out_features: second dimension of matrix W.
bias: If true, add bias
init_method: method to initialize weights.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
):
super(RowParallelLinear, self).__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size
self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
if bias:
self.bias = Parameter(torch.Tensor(self.out_features))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
# Initialize weight.
_initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.input_size_per_partition,
partition_dim = 1,
init_method = init_method,
)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
output_parallel = F.linear(input_, self.weight) # X_i * W_i^T (the bias, if any, is added after the all-reduce)
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel)
if self.bias is not None:
output = output_ + self.bias
else:
output = output_
return output
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
) -> None:
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
self._weight = None
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
# Allocate weights.
self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
# And initialize.
_initialize_affine_weight(
self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
"""
Performs an embedding lookup for input tokens in the parallelized embedding layer
1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
"""
# Build the mask for out-of-vocabulary tokens.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
# Get the embeddings for the valid tokens.
output_parallel = F.embedding(
masked_input,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
# Embedding of out-of-vocabulary tokens is set to 0.
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs to get the final output.
output = reduce_from_model_parallel_region(output_parallel)
return output
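As a single-process sanity check (not part of the commit) of the sharding scheme documented above: ColumnParallelLinear splits its (out, in) weight along dim 0, RowParallelLinear splits its weight along dim 1, and summing the per-shard partial results, which is what the all-reduce in reduce_from_model_parallel_region does across tensor parallel ranks, reproduces the unsharded product. The tensor shapes and tp value below are illustrative only.

import torch

torch.manual_seed(0)
tp = 4                           # hypothetical tensor parallel world size
x = torch.randn(3, 8)            # input activations
w1 = torch.randn(16, 8)          # column-parallel weight, shape (out, in), split along dim 0
w2 = torch.randn(8, 16)          # row-parallel weight, shape (out, in), split along dim 1

reference = x @ w1.t() @ w2.t()  # unsharded two-layer product (F.linear computes X W^T)
partial = sum((x @ w1_i.t()) @ w2_i.t()
              for w1_i, w2_i in zip(w1.chunk(tp, dim=0), w2.chunk(tp, dim=1)))
print(torch.allclose(reference, partial, atol=1e-5))  # True: the sum plays the role of the all-reduce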

View File

@@ -2,7 +2,7 @@
Inspired by FairScale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from picotron.tensor_parallel.utils import split_tensor_along_last_dim
from picotron.tensor_parallel.tp_utils import split_tensor_along_last_dim
import torch.distributed as dist
import torch
import picotron.process_group_manager as pgm

View File

@@ -26,7 +26,7 @@ from picotron.utils import set_all_seed, print, to_readable_format, save_checkpo
from picotron.data import MicroBatchDataLoader
from picotron.process_group_manager import setup_process_group_manager
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from picotron.data_parallel.data_parallel_bucket import DataParallel
from picotron.data_parallel.data_parallel import DataParallelBucket
from picotron.model import Llama
import wandb
@@ -203,7 +203,7 @@ if __name__ == "__main__":
# Context parallel and Data parallel both need gradient synchronization
if pgm.process_group_manager.cp_dp_world_size > 1:
model = DataParallel(model)
model = DataParallelBucket(model)
print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
start_time = time.time()
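For context on how the renamed wrapper is typically driven, here is a hypothetical sketch (not taken from train.py; micro_batches, loss_fn, and optimizer are placeholder names, and model is assumed to already be wrapped in DataParallelBucket under a distributed launch so the process groups exist): no_sync skips the bucket all-reduce for all but the last micro-batch, and reset() clears the main_grad buckets after the optimizer step.

import contextlib

for i, (inputs, targets) in enumerate(micro_batches):
    last = (i == len(micro_batches) - 1)
    ctx = contextlib.nullcontext() if last else model.no_sync()
    with ctx:                        # only the last micro-batch triggers the bucket all-reduce
        loss = loss_fn(model(inputs), targets)
        loss.backward()              # param hooks accumulate gradients into main_grad
optimizer.step()                     # _post_backward already copied main_grad back into .grad
optimizer.zero_grad()
model.reset()                        # zero the main_grad buckets before the next training step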