split/merge into different files tp and dp
parent db926026a6
commit 77e85fe490
@@ -1,10 +1,13 @@
import contextlib
import torch
import torch.distributed as dist
import contextlib
from torch import nn
from torch.autograd import Variable

from picotron.data_parallel.bucket import BucketManager
import picotron.process_group_manager as pgm

class DataParallel(nn.Module):
class DataParallelNaive(nn.Module):
    def __init__(self, module):
        """
        Initializes the DataParallel wrapper for a given module.
@@ -49,4 +52,115 @@ class DataParallel(nn.Module):
        """
        self.require_backward_grad_sync = False
        yield
        self.require_backward_grad_sync = True
        self.require_backward_grad_sync = True

class DataParallelBucket(nn.Module):
    def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):
        """
        Initialize the DataParallelBucket module.

        Args:
            module (nn.Module): The model to be parallelized.
            process_group: The process group for gradient synchronization, which can be either
                a data parallel group or a context parallel group.
            bucket_cap_mb (int, optional): The maximum size of each gradient synchronization bucket in megabytes.
                Defaults to 25 MB.
            grad_type (torch.dtype, optional): The data type of gradients, defaulting to float32.
        """
        super().__init__()
        self.module = module
        self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
        grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
        bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket
        self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.cp_dp_group, bucket_size, grad_type)
        self.register_backward_hook()
        self._post_backward_callback_set = False # whether the callback for wait gradient synchronization is set

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        return self.module.backward(input_tensor, output_tensor, output_tensor_grad)

    def get_flops(self, *args, **kwargs):
        return self.module.get_flops(*args, **kwargs)

    def register_backward_hook(self):
        """
        Registers a backward hook to manually accumulate and synchronize gradients.

        This hook serves two main purposes:
        1. PyTorch does not natively support gradient accumulation with mixed precision.
        2. After gradient accumulation, it flags parameters as ready for synchronization.

        The gradient accumulation functions are stored to prevent them from going out of scope.

        References:
        - https://github.com/NVIDIA/Megatron-LM/issues/690
        - https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html
        - https://arxiv.org/abs/2006.15704 (page 5)
        """
        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
                grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
                self.grad_accs.append(grad_acc_fn)

    def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
        """
        Creates a hook for each parameter to handle gradient accumulation and synchronization.
        """
        def param_hook(*unused):
            """
            The hook called after the gradient is ready. It performs the following:
            1. Accumulates the gradient into the main gradient.
            2. Adds a post-backward callback to wait for gradient synchronization completion.
            3. Marks the parameter as ready for synchronization.
            """
            if param.requires_grad:
                assert param.grad is not None
                param.main_grad.add_(param.grad.data) # accumulate the gradients
                param.grad = None

                # skip the gradient synchronization (gradient accumulation/PP micro batches)
                if self.require_backward_grad_sync:
                    # Add a callback to wait for gradient synchronization. Ensures the callback is added only once.
                    # Callback is executed after the backward pass. It should be added per backward pass.
                    if not self._post_backward_callback_set:
                        Variable._execution_engine.queue_callback(self._post_backward)
                        self._post_backward_callback_set = True

                # mark the parameter as ready for gradient synchronization.
                bucket_manager.mark_param_as_ready(param)
        return param_hook

    @contextlib.contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization."""
        self.require_backward_grad_sync = False
        yield
        self.require_backward_grad_sync = True

    def _post_backward(self):
        """
        A post-backward callback that waits for gradient synchronization to finish, then copies
        the synchronized gradients back to the parameters' grad attribute.

        This method is called after the backward pass and before the optimizer step.
        """
        self.bucket_manager.wait()
        self._post_backward_callback_set = False
        # copy to params.grad so we can use the optimizer to update the parameters
        for p in self.module.parameters():
            if p.requires_grad:
                p.grad = p.main_grad.to(p.dtype) # In PyTorch, you cannot assign a gradient with one data type to a tensor of another data type.

    def reset(self):
        """
        Reset the bucket manager and zero out gradients in the model
        """
        self.bucket_manager.reset()
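For context, the class above is driven from a training loop through `no_sync()` and `reset()`. Below is a minimal usage sketch; the loop shape and the names `micro_batches`, `loss_fn`, and `optimizer` are illustrative assumptions rather than code from this commit, and it assumes process groups are already initialized (as train.py does via `setup_process_group_manager`).

```python
from picotron.data_parallel.data_parallel import DataParallelBucket

def train_one_step(model: DataParallelBucket, optimizer, micro_batches, loss_fn):
    # Skip the bucket all-reduce on every micro-batch except the last one.
    for micro_batch in micro_batches[:-1]:
        with model.no_sync():
            loss = loss_fn(model(**micro_batch))
            loss.backward()  # hooks accumulate into param.main_grad, no communication
    # Last micro-batch: hooks queue _post_backward and the buckets are all-reduced.
    loss = loss_fn(model(**micro_batches[-1]))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    model.reset()  # zero the main_grad buckets before the next step
    return loss
```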
@@ -1,118 +0,0 @@
import torch
import contextlib
from torch import nn
from torch.autograd import Variable

from picotron.data_parallel.bucket import BucketManager
import picotron.process_group_manager as pgm

class DataParallel(nn.Module):
    def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):
        """
        Initialize the DataParallel module.

        Args:
            module (nn.Module): The model to be parallelized.
            process_group: The process group for gradient synchronization, which can be either
                a data parallel group or a context parallel group.
            bucket_cap_mb (int, optional): The maximum size of each gradient synchronization bucket in megabytes.
                Defaults to 25 MB.
            grad_type (torch.dtype, optional): The data type of gradients, defaulting to float32.
        """
        super().__init__()
        self.module = module
        self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
        grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
        bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket
        self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.cp_dp_group, bucket_size, grad_type)
        self.register_backward_hook()
        self._post_backward_callback_set = False # whether the callback for wait gradient synchronization is set

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        return self.module.backward(input_tensor, output_tensor, output_tensor_grad)

    def get_flops(self, *args, **kwargs):
        return self.module.get_flops(*args, **kwargs)

    def register_backward_hook(self):
        """
        Registers a backward hook to manually accumulate and synchronize gradients.

        This hook serves two main purposes:
        1. PyTorch does not natively support gradient accumulation with mixed precision.
        2. After gradient accumulation, it flags parameters as ready for synchronization.

        The gradient accumulation functions are stored to prevent them from going out of scope.

        References:
        - https://github.com/NVIDIA/Megatron-LM/issues/690
        - https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html
        - https://arxiv.org/abs/2006.15704 (page 5)
        """
        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
                grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
                self.grad_accs.append(grad_acc_fn)

    def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
        """
        Creates a hook for each parameter to handle gradient accumulation and synchronization.
        """
        def param_hook(*unused):
            """
            The hook called after the gradient is ready. It performs the following:
            1. Accumulates the gradient into the main gradient.
            2. Adds a post-backward callback to wait for gradient synchronization completion.
            3. Marks the parameter as ready for synchronization.
            """
            if param.requires_grad:
                assert param.grad is not None
                param.main_grad.add_(param.grad.data) # accumulate the gradients
                param.grad = None

                # skip the gradient synchronization (gradient accumulation/PP micro batches)
                if self.require_backward_grad_sync:
                    # Add a callback to wait for gradient synchronization. Ensures the callback is added only once.
                    # Callback is executed after the backward pass. It should be added per backward pass.
                    if not self._post_backward_callback_set:
                        Variable._execution_engine.queue_callback(self._post_backward)
                        self._post_backward_callback_set = True

                # mark the parameter as ready for gradient synchronization.
                bucket_manager.mark_param_as_ready(param)
        return param_hook

    @contextlib.contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization."""
        self.require_backward_grad_sync = False
        yield
        self.require_backward_grad_sync = True

    def _post_backward(self):
        """
        A post-backward callback that waits for gradient synchronization to finish, then copies
        the synchronized gradients back to the parameters' grad attribute.

        This method is called after the backward pass and before the optimizer step.
        """
        self.bucket_manager.wait()
        self._post_backward_callback_set = False
        # copy to params.grad so we can use the optimizer to update the parameters
        for p in self.module.parameters():
            if p.requires_grad:
                p.grad = p.main_grad.to(p.dtype) # In PyTorch, you cannot assign a gradient with one data type to a tensor of another data type.

    def reset(self):
        """
        Reset the bucket manager and zero out gradients in the model
        """
        self.bucket_manager.reset()
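As a quick sanity check on the bucket sizing used in both versions of the class, the snippet below reproduces the `bucket_cap_mb * 1024 * 1024 // grad_size` arithmetic; it is a standalone illustration, not code from the repository.

```python
# 25 MB buckets hold 6,553,600 float32 gradients (4 bytes each)
# or 13,107,200 bfloat16 gradients (2 bytes each).
bucket_cap_mb = 25
for dtype_name, grad_size in (("float32", 4), ("bfloat16", 2)):
    bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size  # number of gradients in one bucket
    print(dtype_name, bucket_size)
```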
@@ -1,256 +0,0 @@
"""
Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from picotron.tensor_parallel.utils import VocabUtility
import torch
import math
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from typing import Callable, Optional
import picotron.process_group_manager as pgm
from picotron.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region

def initialize_weight_tensor(weight, vocab_embedding=False):
    """
    Initialize the weight tensor with the default initialization method in PyTorch
    If not a vocab embedding, it uses U(-sqrt(k), sqrt(k)) with k = 1/in_features.
    If it's a vocab embedding, it uses a normal distribution N(0, 1).
    """
    if not vocab_embedding:
        # Get the in_features from the shape of the weight tensor
        _, in_features = weight.shape

        # Calculate k and the uniform bounds
        k = 1 / in_features
        bound = math.sqrt(k)

        # Initialize weights with U(-sqrt(k), sqrt(k))
        torch.nn.init.uniform_(weight, -bound, bound)
    else:
        # Initialize Vocab embedding with N(0, 1)
        torch.nn.init.normal_(weight, mean=0.0, std=1.0)

def _initialize_affine_weight(
    weight: torch.Tensor,
    out_features: int,
    in_features: int,
    per_partition_size: int,
    partition_dim: int,
    init_method: Callable[[torch.Tensor], torch.Tensor]
) -> Optional[torch.Tensor]:
    """
    Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
    Args:
        weight: The weight tensor that will be initialized for the current partition.
        out_features: second dimension of weight matrix W.
        in_features: first dimension of weight matrix W.
        per_partition_size: The size of the weight partition assigned to each process.
        partition_dim: The dimension along which the weight matrix is split for parallelism.
        init_method: The method used to initialize the weight values.
    """

    # If we only use 1 process for model parallelism, we can simply initialize the weight
    if pgm.process_group_manager.tp_world_size == 1:
        init_method(weight)
        return None

    # Initialize master weight
    master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
    init_method(master_weight)

    # Split the model into size of per_partition_size and take the corresponding partition
    weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
    weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()

    return None

class ColumnParallelLinear(torch.nn.Module):
    """Column Parallel Linear layer
    Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
    This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension.
    Arguments:
        in_features: first dimension of weight matrix W.
        out_features: second dimension of weight matrix W.
        bias: If true, add bias
        init_method: method to initialize weights
        gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
        gather_output: bool = False,
    ) -> None:
        super(ColumnParallelLinear, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
        self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
        self.gather_output = gather_output

        # Allocate space for the weight and bias
        # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
        self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
        if bias:
            self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

        # Initialize weight.
        _initialize_affine_weight(
            self.weight,
            self.out_features,
            self.in_features,
            self.output_size_per_partition,
            partition_dim = 0,
            init_method = init_method,
        )

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        input_parallel = copy_to_model_parallel_region(input_)
        output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
        if self.gather_output:
            output = gather_from_model_parallel_region(output)
        return output

class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.
    Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
               -   -
              | W_1 |
              | .   |
          W = | .   |        X = [X_1, ..., X_p]
              | .   |
              | W_p |
               -   -
    We assume that X is already parallelized. This is the case after ColumnParallelLinear.
    This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method.
    Arguments:
        in_features: first dimension of matrix W.
        out_features: second dimension of matrix W.
        bias: If true, add bias
        init_method: method to initialize weights.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
    ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size

        self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
        if bias:
            self.bias = Parameter(torch.Tensor(self.out_features))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

        # Initialize weight.
        _initialize_affine_weight(
            self.weight,
            self.out_features,
            self.in_features,
            self.input_size_per_partition,
            partition_dim = 1,
            init_method = init_method,
        )

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b
        # All-reduce across all the partitions.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if self.bias is not None:
            output = output_ + self.bias
        else:
            output = output_
        return output

class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.
    This is mainly adapted from torch.nn.Embedding and all the default values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
    ) -> None:
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        self._weight = None
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
        )
        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

        # Allocate weights.
        self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
        # And initialize.
        _initialize_affine_weight(
            self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
        )

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """
        Performs an embedding lookup for input tokens in the parallelized embedding layer
        1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
        2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
        3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
        """
        # Build the mask for out-of-vocabulary tokens.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Mask the input.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
        # Get the embeddings for the valid tokens.
        output_parallel = F.embedding(
            masked_input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        # Embedding of out-of-vocabulary tokens is set to 0.
        output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs to get the final output.
        output = reduce_from_model_parallel_region(output_parallel)
        return output
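The VocabParallelEmbedding above (and its re-added copy in the merged tensor-parallel file below) derives its shard boundaries from `VocabUtility.vocab_range_from_global_vocab_size`, which is not part of this diff. The sketch below shows the behavior that call is assumed to have, based on how `num_embeddings_per_partition` is computed; treat it as an illustration, not the actual `tp_utils` implementation.

```python
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int):
    # Assumed behavior: each tensor-parallel rank owns a contiguous [start, end) slice of the vocabulary.
    per_partition_vocab_size = global_vocab_size // world_size
    index_first = rank * per_partition_vocab_size
    index_last = index_first + per_partition_vocab_size
    return index_first, index_last

# e.g. a 32768-token vocabulary with tp_world_size=4: rank 2 owns tokens [16384, 24576)
```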
@@ -1,10 +1,21 @@
from functools import partial
from picotron.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
"""
Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from picotron.tensor_parallel.tp_utils import VocabUtility
import torch
import math
import torch.nn.init as init
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from typing import Callable, Optional
import picotron.process_group_manager as pgm
from functools import partial
import torch.nn.init as init
from picotron.tensor_parallel.tp_communications import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region

class TensorParallel():
    def __init__(self, model, init_method = initialize_weight_tensor):
    def __init__(self, model, init_method):
        super().__init__()

        module_linear_name_stype_mapping_list = [
@@ -49,4 +60,247 @@ class TensorParallel():
                    embedding_dim=linear_layer.embedding_dim,
                    init_method=partial(self.init_method, vocab_embedding=True)
                )
                setattr(module, linear_proj_name, new_linear_layer)
                setattr(module, linear_proj_name, new_linear_layer)

def initialize_weight_tensor(weight, vocab_embedding=False):
    """
    Initialize the weight tensor with the default initialization method in PyTorch
    If not a vocab embedding, it uses U(-sqrt(k), sqrt(k)) with k = 1/in_features.
    If it's a vocab embedding, it uses a normal distribution N(0, 1).
    """
    if not vocab_embedding:
        # Get the in_features from the shape of the weight tensor
        _, in_features = weight.shape

        # Calculate k and the uniform bounds
        k = 1 / in_features
        bound = math.sqrt(k)

        # Initialize weights with U(-sqrt(k), sqrt(k))
        torch.nn.init.uniform_(weight, -bound, bound)
    else:
        # Initialize Vocab embedding with N(0, 1)
        torch.nn.init.normal_(weight, mean=0.0, std=1.0)

def _initialize_affine_weight(
    weight: torch.Tensor,
    out_features: int,
    in_features: int,
    per_partition_size: int,
    partition_dim: int,
    init_method: Callable[[torch.Tensor], torch.Tensor]
) -> Optional[torch.Tensor]:
    """
    Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
    Args:
        weight: The weight tensor that will be initialized for the current partition.
        out_features: second dimension of weight matrix W.
        in_features: first dimension of weight matrix W.
        per_partition_size: The size of the weight partition assigned to each process.
        partition_dim: The dimension along which the weight matrix is split for parallelism.
        init_method: The method used to initialize the weight values.
    """

    # If we only use 1 process for model parallelism, we can simply initialize the weight
    if pgm.process_group_manager.tp_world_size == 1:
        init_method(weight)
        return None

    # Initialize master weight
    master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
    init_method(master_weight)

    # Split the model into size of per_partition_size and take the corresponding partition
    weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
    weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()

    return None

class ColumnParallelLinear(torch.nn.Module):
    """Column Parallel Linear layer
    Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
    This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension.
    Arguments:
        in_features: first dimension of weight matrix W.
        out_features: second dimension of weight matrix W.
        bias: If true, add bias
        init_method: method to initialize weights
        gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
        gather_output: bool = False,
    ) -> None:
        super(ColumnParallelLinear, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
        self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
        self.gather_output = gather_output

        # Allocate space for the weight and bias
        # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
        self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
        if bias:
            self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

        # Initialize weight.
        _initialize_affine_weight(
            self.weight,
            self.out_features,
            self.in_features,
            self.output_size_per_partition,
            partition_dim = 0,
            init_method = init_method,
        )

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        input_parallel = copy_to_model_parallel_region(input_)
        output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
        if self.gather_output:
            output = gather_from_model_parallel_region(output)
        return output

class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.
    Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
               -   -
              | W_1 |
              | .   |
          W = | .   |        X = [X_1, ..., X_p]
              | .   |
              | W_p |
               -   -
    We assume that X is already parallelized. This is the case after ColumnParallelLinear.
    This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method.
    Arguments:
        in_features: first dimension of matrix W.
        out_features: second dimension of matrix W.
        bias: If true, add bias
        init_method: method to initialize weights.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
    ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size

        self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
        if bias:
            self.bias = Parameter(torch.Tensor(self.out_features))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

        # Initialize weight.
        _initialize_affine_weight(
            self.weight,
            self.out_features,
            self.in_features,
            self.input_size_per_partition,
            partition_dim = 1,
            init_method = init_method,
        )

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b
        # All-reduce across all the partitions.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if self.bias is not None:
            output = output_ + self.bias
        else:
            output = output_
        return output

class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.
    This is mainly adapted from torch.nn.Embedding and all the default values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
    ) -> None:
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        self._weight = None
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
        )
        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

        # Allocate weights.
        self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
        # And initialize.
        _initialize_affine_weight(
            self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
        )

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """
        Performs an embedding lookup for input tokens in the parallelized embedding layer
        1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
        2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
        3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
        """
        # Build the mask for out-of-vocabulary tokens.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Mask the input.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
        # Get the embeddings for the valid tokens.
        output_parallel = F.embedding(
            masked_input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        # Embedding of out-of-vocabulary tokens is set to 0.
        output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs to get the final output.
        output = reduce_from_model_parallel_region(output_parallel)
        return output
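The docstrings above note that `RowParallelLinear` expects an input that is already split along its last dimension, which is exactly what `ColumnParallelLinear` produces when `gather_output=False`. A hypothetical composition is sketched below; the module path `picotron.tensor_parallel.tensor_parallel`, the hidden sizes, and the GELU activation are assumptions for illustration, and an initialized process group manager is required before constructing the layers.

```python
import torch.nn as nn
import torch.nn.functional as F

# Assumed import path for the merged tensor-parallel file; adjust to wherever the classes live.
from picotron.tensor_parallel.tensor_parallel import ColumnParallelLinear, RowParallelLinear

class ParallelMLP(nn.Module):
    def __init__(self, hidden_size: int = 4096, intermediate_size: int = 11008):
        super().__init__()
        # Shards the intermediate dimension across TP ranks; the output stays sharded (gather_output=False).
        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size, gather_output=False)
        # Consumes the sharded activation and all-reduces the partial results back to full size.
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=True)

    def forward(self, x):
        return self.down_proj(F.gelu(self.up_proj(x)))
```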
@@ -2,7 +2,7 @@
Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from picotron.tensor_parallel.utils import split_tensor_along_last_dim
from picotron.tensor_parallel.tp_utils import split_tensor_along_last_dim
import torch.distributed as dist
import torch
import picotron.process_group_manager as pgm
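The only change in this hunk is the import path of `split_tensor_along_last_dim` (utils moved to tp_utils). The helper itself is not shown in the diff; a minimal sketch of its assumed behavior, splitting a tensor into equal chunks along its last dimension, is given below for reference only.

```python
import torch

def split_tensor_along_last_dim(tensor: torch.Tensor, num_partitions: int):
    # Assumed behavior: equal, contiguous chunks along the last dimension.
    last_dim_size = tensor.size(-1) // num_partitions
    return torch.split(tensor, last_dim_size, dim=-1)
```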
train.py (4 changed lines)
@@ -26,7 +26,7 @@ from picotron.utils import set_all_seed, print, to_readable_format, save_checkpo
from picotron.data import MicroBatchDataLoader
from picotron.process_group_manager import setup_process_group_manager
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from picotron.data_parallel.data_parallel_bucket import DataParallel
from picotron.data_parallel.data_parallel import DataParallelBucket
from picotron.model import Llama
import wandb

@@ -203,7 +203,7 @@ if __name__ == "__main__":

    # Context parallel and Data parallel both need gradient synchronization
    if pgm.process_group_manager.cp_dp_world_size > 1:
        model = DataParallel(model)
        model = DataParallelBucket(model)

    print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
    start_time = time.time()