From 77e85fe49020b9934a385accc7b5b4b59f069c6b Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Nov 2024 16:26:11 +0000 Subject: [PATCH] split/merge into different files tp and dp --- picotron/data_parallel/data_parallel.py | 120 +++++++- .../data_parallel/data_parallel_bucket.py | 118 -------- picotron/tensor_parallel/layers.py | 256 ----------------- picotron/tensor_parallel/tensor_parallel.py | 264 +++++++++++++++++- .../{mappings.py => tp_communications.py} | 2 +- .../tensor_parallel/{utils.py => tp_utils.py} | 0 train.py | 4 +- 7 files changed, 379 insertions(+), 385 deletions(-) delete mode 100644 picotron/data_parallel/data_parallel_bucket.py delete mode 100644 picotron/tensor_parallel/layers.py rename picotron/tensor_parallel/{mappings.py => tp_communications.py} (97%) rename picotron/tensor_parallel/{utils.py => tp_utils.py} (100%) diff --git a/picotron/data_parallel/data_parallel.py b/picotron/data_parallel/data_parallel.py index 9bbbef6..aeedfa1 100644 --- a/picotron/data_parallel/data_parallel.py +++ b/picotron/data_parallel/data_parallel.py @@ -1,10 +1,13 @@ -import contextlib import torch import torch.distributed as dist +import contextlib from torch import nn +from torch.autograd import Variable + +from picotron.data_parallel.bucket import BucketManager import picotron.process_group_manager as pgm -class DataParallel(nn.Module): +class DataParallelNaive(nn.Module): def __init__(self, module): """ Initializes the DataParallel wrapper for a given module. @@ -49,4 +52,115 @@ class DataParallel(nn.Module): """ self.require_backward_grad_sync = False yield - self.require_backward_grad_sync = True \ No newline at end of file + self.require_backward_grad_sync = True + +class DataParallelBucket(nn.Module): + def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32): + """ + Initialize the DataParallelBucket module. + + Args: + module (nn.Module): The model to be parallelized. + process_group: The process group for gradient synchronization, which can be either + a data parallel group or a context parallel group. + bucket_cap_mb (int, optional): The maximum size of each gradient synchronization bucket in megabytes. + Defaults to 25 MB. + grad_type (torch.dtype, optional): The data type of gradients, defaulting to float32. + """ + super().__init__() + self.module = module + self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation + grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes + bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket + self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.cp_dp_group, bucket_size, grad_type) + self.register_backward_hook() + self._post_backward_callback_set = False # whether the callback for wait gradient synchronization is set + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) + + def backward(self, input_tensor, output_tensor, output_tensor_grad): + return self.module.backward(input_tensor, output_tensor, output_tensor_grad) + + def get_flops(self, *args, **kwargs): + return self.module.get_flops(*args, **kwargs) + + def register_backward_hook(self): + """ + Registers a backward hook to manually accumulate and synchronize gradients. + + This hook serves two main purposes: + 1. PyTorch does not natively support gradient accumulation with mixed precision. + 2. 
After gradient accumulation, it flags parameters as ready for synchronization. + + The gradient accumulation functions are stored to prevent them from going out of scope. + + References: + - https://github.com/NVIDIA/Megatron-LM/issues/690 + - https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html + - https://arxiv.org/abs/2006.15704 (page 5) + """ + self.grad_accs = [] + for param in self.module.parameters(): + if param.requires_grad: + # Expand so we get access to grad_fn. + param_tmp = param.expand_as(param) + # Get the gradient accumulator function. + grad_acc_fn = param_tmp.grad_fn.next_functions[0][0] + grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager)) + self.grad_accs.append(grad_acc_fn) + + def _make_param_hook(self, param: torch.nn.Parameter,bucket_manager: BucketManager): + """ + Creates the a hook for each parameter to handle gradient accumulation and synchronization. + """ + def param_hook(*unused): + """ + The hook called after the gradient is ready. It performs the following: + 1. Accumulates the gradient into the main gradient. + 2. Adds a post-backward callback to wait for gradient synchronization completion. + 3. Marks the parameter as ready for synchronization. + """ + if param.requires_grad: + assert param.grad is not None + param.main_grad.add_(param.grad.data) # accumulate the gradients + param.grad = None + + # skip the gradient synchronization (gradient accumulation/PP micro batches) + if self.require_backward_grad_sync: + # Add a callback to wait for gradient synchronization. Ensures the callback is added only once. + # Callback is executed after the backward pass. It should be added per backward pass. + if not self._post_backward_callback_set: + Variable._execution_engine.queue_callback(self._post_backward) + self._post_backward_callback_set = True + + # mark the parameter as ready for gradient synchronization. + bucket_manager.mark_param_as_ready(param) + return param_hook + + @contextlib.contextmanager + def no_sync(self): + """A context manager to disable gradient synchronization.""" + self.require_backward_grad_sync = False + yield + self.require_backward_grad_sync = True + + def _post_backward(self): + """ + A post-backward callback that waits for gradient synchronization to finish, then copies + the synchronized gradients back to the parameters' grad attribute. + + This method is called after the backward pass and before the optimizer step. + """ + self.bucket_manager.wait() + self._post_backward_callback_set = False + # copy to params.grad so we can use the optimizer to update the parameters + for p in self.module.parameters(): + if p.requires_grad: + p.grad = p.main_grad.to(p.dtype) # In PyTorch, you cannot assign a gradient with one data type to a tensor of another data type. 
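Note (illustrative, not part of the patch): a minimal sketch of how DataParallelBucket.no_sync() and the post-backward callback are expected to interact during gradient accumulation. The names model, micro_batches, optimizer, grad_acc_steps, and compute_loss are hypothetical placeholders; only DataParallelBucket comes from this file. With the default bucket_cap_mb=25 and float32 gradients, each bucket holds 25 * 1024 * 1024 // 4 = 6,553,600 gradient elements before its all-reduce is launched.

import contextlib

def train_step(model, micro_batches, optimizer, grad_acc_steps):
    # `model` is assumed to already be wrapped in DataParallelBucket.
    optimizer.zero_grad()
    for i, batch in enumerate(micro_batches):
        is_last = (i == grad_acc_steps - 1)
        # Outside no_sync(), the param hooks queue _post_backward() and mark
        # buckets as ready; inside it, gradients only accumulate into main_grad.
        ctx = model.no_sync() if not is_last else contextlib.nullcontext()
        with ctx:
            loss = compute_loss(model, batch)  # hypothetical helper
            loss.backward()
    # After the last backward pass, _post_backward() has waited on the async
    # all-reduces and copied main_grad back into p.grad, so the optimizer
    # updates synchronized gradients.
    optimizer.step()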
+ + def reset(self): + """ + Reset the bucket manager and zero out gradients in the model + """ + self.bucket_manager.reset() \ No newline at end of file diff --git a/picotron/data_parallel/data_parallel_bucket.py b/picotron/data_parallel/data_parallel_bucket.py deleted file mode 100644 index 8f117b3..0000000 --- a/picotron/data_parallel/data_parallel_bucket.py +++ /dev/null @@ -1,118 +0,0 @@ -import torch -import contextlib -from torch import nn -from torch.autograd import Variable - -from picotron.data_parallel.bucket import BucketManager -import picotron.process_group_manager as pgm - -class DataParallel(nn.Module): - def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32): - """ - Initialize the DataParallel module. - - Args: - module (nn.Module): The model to be parallelized. - process_group: The process group for gradient synchronization, which can be either - a data parallel group or a context parallel group. - bucket_cap_mb (int, optional): The maximum size of each gradient synchronization bucket in megabytes. - Defaults to 25 MB. - grad_type (torch.dtype, optional): The data type of gradients, defaulting to float32. - """ - super().__init__() - self.module = module - self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation - grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes - bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket - self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.cp_dp_group, bucket_size, grad_type) - self.register_backward_hook() - self._post_backward_callback_set = False # whether the callback for wait gradient synchronization is set - - def forward(self, *inputs, **kwargs): - return self.module(*inputs, **kwargs) - - def backward(self, input_tensor, output_tensor, output_tensor_grad): - return self.module.backward(input_tensor, output_tensor, output_tensor_grad) - - def get_flops(self, *args, **kwargs): - return self.module.get_flops(*args, **kwargs) - - def register_backward_hook(self): - """ - Registers a backward hook to manually accumulate and synchronize gradients. - - This hook serves two main purposes: - 1. PyTorch does not natively support gradient accumulation with mixed precision. - 2. After gradient accumulation, it flags parameters as ready for synchronization. - - The gradient accumulation functions are stored to prevent them from going out of scope. - - References: - - https://github.com/NVIDIA/Megatron-LM/issues/690 - - https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html - - https://arxiv.org/abs/2006.15704 (page 5) - """ - self.grad_accs = [] - for param in self.module.parameters(): - if param.requires_grad: - # Expand so we get access to grad_fn. - param_tmp = param.expand_as(param) - # Get the gradient accumulator function. - grad_acc_fn = param_tmp.grad_fn.next_functions[0][0] - grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager)) - self.grad_accs.append(grad_acc_fn) - - def _make_param_hook(self, param: torch.nn.Parameter,bucket_manager: BucketManager): - """ - Creates the a hook for each parameter to handle gradient accumulation and synchronization. - """ - def param_hook(*unused): - """ - The hook called after the gradient is ready. It performs the following: - 1. Accumulates the gradient into the main gradient. - 2. 
Adds a post-backward callback to wait for gradient synchronization completion. - 3. Marks the parameter as ready for synchronization. - """ - if param.requires_grad: - assert param.grad is not None - param.main_grad.add_(param.grad.data) # accumulate the gradients - param.grad = None - - # skip the gradient synchronization (gradient accumulation/PP micro batches) - if self.require_backward_grad_sync: - # Add a callback to wait for gradient synchronization. Ensures the callback is added only once. - # Callback is executed after the backward pass. It should be added per backward pass. - if not self._post_backward_callback_set: - Variable._execution_engine.queue_callback(self._post_backward) - self._post_backward_callback_set = True - - # mark the parameter as ready for gradient synchronization. - bucket_manager.mark_param_as_ready(param) - return param_hook - - @contextlib.contextmanager - def no_sync(self): - """A context manager to disable gradient synchronization.""" - self.require_backward_grad_sync = False - yield - self.require_backward_grad_sync = True - - def _post_backward(self): - """ - A post-backward callback that waits for gradient synchronization to finish, then copies - the synchronized gradients back to the parameters' grad attribute. - - This method is called after the backward pass and before the optimizer step. - """ - self.bucket_manager.wait() - self._post_backward_callback_set = False - # copy to params.grad so we can use the optimizer to update the parameters - for p in self.module.parameters(): - if p.requires_grad: - p.grad = p.main_grad.to(p.dtype) # In PyTorch, you cannot assign a gradient with one data type to a tensor of another data type. - - def reset(self): - """ - Reset the bucket manager and zero out gradients in the model - """ - self.bucket_manager.reset() \ No newline at end of file diff --git a/picotron/tensor_parallel/layers.py b/picotron/tensor_parallel/layers.py deleted file mode 100644 index 9dd562d..0000000 --- a/picotron/tensor_parallel/layers.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -Inspired by Fair Scale/Megatron's Tensor Parallelism implementation -Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale -""" -from picotron.tensor_parallel.utils import VocabUtility -import torch -import math -import torch.nn.init as init -import torch.nn.functional as F -from torch.nn.parameter import Parameter -from typing import Callable, Optional -import picotron.process_group_manager as pgm -from picotron.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region - -def initialize_weight_tensor(weight, vocab_embedding=False): - """ - Initialize the weight tensor with the default initialization method in PyTorch - If not a vocab embedding, it uses U(-sqrt(k), sqrt(k)) with k = 1/in_features. - If it's a vocab embedding, it uses a normal distribution N(0, 1). 
- """ - if not vocab_embedding: - # Get the in_features from the shape of the weight tensor - _, in_features = weight.shape - - # Calculate k and the uniform bounds - k = 1 / in_features - bound = math.sqrt(k) - - # Initialize weights with U(-sqrt(k), sqrt(k)) - torch.nn.init.uniform_(weight, -bound, bound) - else: - # Initialize Vocab embedding with N(0, 1) - torch.nn.init.normal_(weight, mean=0.0, std=1.0) - -def _initialize_affine_weight( - weight: torch.Tensor, - out_features: int, - in_features: int, - per_partition_size: int, - partition_dim: int, - init_method: Callable[[torch.Tensor], torch.Tensor] -) -> Optional[torch.Tensor]: - """ - Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight - Args: - weight: The weight tensor that will be initialized for the current partition. - out_features: second dimension of weight matrix W. - in_features: first dimension of weight matrix W. - per_partition_size: The size of the weight partition assigned to each process. - partition_dim: The dimension along which the weight matrix is split for parallelism. - init_method: The method used to initialize the weight values. - """ - - # If we only use 1 process for model parallelism, we can simply initialize the weight - if pgm.process_group_manager.tp_world_size == 1: - init_method(weight) - return None - - # Initialize master weight - master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False) - init_method(master_weight) - - # Split the model into size of per_partition_size and take the corresponding partition - weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim) - weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous() - - return None - -class ColumnParallelLinear(torch.nn.Module): - """Column Parallel Linear layer - Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p] - This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension. - Arguments: - in_features: first dimension of weight matrix W. - out_features: second dimension of weight matrix W. - bias: If true, add bias - init_method: method to initialize weights - gather_output: If true, gather the output from all the partitions. This is used for the last linear layer - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, - gather_output: bool = False, - ) -> None: - super(ColumnParallelLinear, self).__init__() - - self.in_features = in_features - self.out_features = out_features - assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size" - self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size - self.gather_output = gather_output - - # Allocate space for the weight and bias - # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions - self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i - if bias: - self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter("bias", None) - - # Initialize weight. 
- _initialize_affine_weight( - self.weight, - self.out_features, - self.in_features, - self.output_size_per_partition, - partition_dim = 0, - init_method = init_method, - ) - - def forward(self, input_: torch.Tensor) -> torch.Tensor: - input_parallel = copy_to_model_parallel_region(input_) - output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i - if self.gather_output: - output = gather_from_model_parallel_region(output) - return output - -class RowParallelLinear(torch.nn.Module): - """Linear layer with row parallelism. - Y = XW + b. W is parallelized along its first dimension and X along its second dimension as: - - - - | W_1 | - | . | - W = | . | X = [X_1, ..., X_p] - | . | - | W_p | - - - - We assume that X is already parallelized. This is the case after ColumnParallelLinear. - This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method. - Arguments: - in_features: first dimension of matrix W. - out_features: second dimension of matrix W. - bias: If true, add bias - init_method: method to initialize weights. - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, - ): - super(RowParallelLinear, self).__init__() - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size - - self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition)) - if bias: - self.bias = Parameter(torch.Tensor(self.out_features)) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter("bias", None) - - # Initialize weight. - _initialize_affine_weight( - self.weight, - self.out_features, - self.in_features, - self.input_size_per_partition, - partition_dim = 1, - init_method = init_method, - ) - - def forward(self, input_: torch.Tensor) -> torch.Tensor: - output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b - # All-reduce across all the partitions. - output_ = reduce_from_model_parallel_region(output_parallel) - if self.bias is not None: - output = output_ + self.bias - else: - output = output_ - return output - -class VocabParallelEmbedding(torch.nn.Module): - """Embedding parallelized in the vocabulary dimension. - This is mainly adapted from torch.nn.Embedding and all the default values are kept. - Arguments: - num_embeddings: vocabulary size. - embedding_dim: size of hidden state. - init_method: method to initialize weights. - """ - - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False, - init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, - ) -> None: - super(VocabParallelEmbedding, self).__init__() - # Keep the input dimensions. - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.padding_idx = padding_idx - self.max_norm = max_norm - self.norm_type = norm_type - self.scale_grad_by_freq = scale_grad_by_freq - self.sparse = sparse - self._weight = None - # Divide the weight matrix along the vocaburaly dimension. 
- self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size - ) - self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index - - # Allocate weights. - self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim)) - # And initialize. - _initialize_affine_weight( - self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method - ) - - def forward(self, input_: torch.Tensor) -> torch.Tensor: - """ - Performs an embedding lookup for input tokens in the parallelized embedding layer - 1. Masks tokens that fall outside the specified vocabulary range and adjusts the input - 2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero - 3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization - """ - # Build the mask for out-of-vocabulary tokens. - input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - # Get the embeddings for the valid tokens. - output_parallel = F.embedding( - masked_input, - self.weight, - self.padding_idx, - self.max_norm, - self.norm_type, - self.scale_grad_by_freq, - self.sparse, - ) - # Embedding of out-of-vocabulary tokens is set to 0. - output_parallel[input_mask, :] = 0.0 - # Reduce across all the model parallel GPUs to get the final output. - output = reduce_from_model_parallel_region(output_parallel) - return output \ No newline at end of file diff --git a/picotron/tensor_parallel/tensor_parallel.py b/picotron/tensor_parallel/tensor_parallel.py index 2f11c19..cd98288 100644 --- a/picotron/tensor_parallel/tensor_parallel.py +++ b/picotron/tensor_parallel/tensor_parallel.py @@ -1,10 +1,21 @@ -from functools import partial -from picotron.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor +""" +Inspired by Fair Scale/Megatron's Tensor Parallelism implementation +Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale +""" +from picotron.tensor_parallel.tp_utils import VocabUtility +import torch +import math import torch.nn.init as init -import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from typing import Callable, Optional +import picotron.process_group_manager as pgm +from functools import partial +import torch.nn.init as init +from picotron.tensor_parallel.tp_communications import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region class TensorParallel(): - def __init__(self, model, init_method = initialize_weight_tensor): + def __init__(self, model, init_method): super().__init__() module_linear_name_stype_mapping_list = [ @@ -49,4 +60,247 @@ class TensorParallel(): embedding_dim=linear_layer.embedding_dim, init_method=partial(self.init_method, vocab_embedding=True) ) - setattr(module, linear_proj_name, new_linear_layer) \ No newline at end of file + setattr(module, linear_proj_name, new_linear_layer) + +def initialize_weight_tensor(weight, vocab_embedding=False): + """ + Initialize the weight tensor with the default initialization method in PyTorch + If not a vocab embedding, it uses U(-sqrt(k), sqrt(k)) with k = 1/in_features. 
+ If it's a vocab embedding, it uses a normal distribution N(0, 1). + """ + if not vocab_embedding: + # Get the in_features from the shape of the weight tensor + _, in_features = weight.shape + + # Calculate k and the uniform bounds + k = 1 / in_features + bound = math.sqrt(k) + + # Initialize weights with U(-sqrt(k), sqrt(k)) + torch.nn.init.uniform_(weight, -bound, bound) + else: + # Initialize Vocab embedding with N(0, 1) + torch.nn.init.normal_(weight, mean=0.0, std=1.0) + +def _initialize_affine_weight( + weight: torch.Tensor, + out_features: int, + in_features: int, + per_partition_size: int, + partition_dim: int, + init_method: Callable[[torch.Tensor], torch.Tensor] +) -> Optional[torch.Tensor]: + """ + Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight + Args: + weight: The weight tensor that will be initialized for the current partition. + out_features: second dimension of weight matrix W. + in_features: first dimension of weight matrix W. + per_partition_size: The size of the weight partition assigned to each process. + partition_dim: The dimension along which the weight matrix is split for parallelism. + init_method: The method used to initialize the weight values. + """ + + # If we only use 1 process for model parallelism, we can simply initialize the weight + if pgm.process_group_manager.tp_world_size == 1: + init_method(weight) + return None + + # Initialize master weight + master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False) + init_method(master_weight) + + # Split the model into size of per_partition_size and take the corresponding partition + weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim) + weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous() + + return None + +class ColumnParallelLinear(torch.nn.Module): + """Column Parallel Linear layer + Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p] + This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension. + Arguments: + in_features: first dimension of weight matrix W. + out_features: second dimension of weight matrix W. + bias: If true, add bias + init_method: method to initialize weights + gather_output: If true, gather the output from all the partitions. This is used for the last linear layer + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, + gather_output: bool = False, + ) -> None: + super(ColumnParallelLinear, self).__init__() + + self.in_features = in_features + self.out_features = out_features + assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size" + self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size + self.gather_output = gather_output + + # Allocate space for the weight and bias + # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions + self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i + if bias: + self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter("bias", None) + + # Initialize weight. 
+ _initialize_affine_weight( + self.weight, + self.out_features, + self.in_features, + self.output_size_per_partition, + partition_dim = 0, + init_method = init_method, + ) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + input_parallel = copy_to_model_parallel_region(input_) + output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i + if self.gather_output: + output = gather_from_model_parallel_region(output) + return output + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + Y = XW + b. W is parallelized along its first dimension and X along its second dimension as: + - - + | W_1 | + | . | + W = | . | X = [X_1, ..., X_p] + | . | + | W_p | + - - + We assume that X is already parallelized. This is the case after ColumnParallelLinear. + This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method. + Arguments: + in_features: first dimension of matrix W. + out_features: second dimension of matrix W. + bias: If true, add bias + init_method: method to initialize weights. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, + ): + super(RowParallelLinear, self).__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size + + self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition)) + if bias: + self.bias = Parameter(torch.Tensor(self.out_features)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter("bias", None) + + # Initialize weight. + _initialize_affine_weight( + self.weight, + self.out_features, + self.in_features, + self.input_size_per_partition, + partition_dim = 1, + init_method = init_method, + ) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b + # All-reduce across all the partitions. + output_ = reduce_from_model_parallel_region(output_parallel) + if self.bias is not None: + output = output_ + self.bias + else: + output = output_ + return output + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + This is mainly adapted from torch.nn.Embedding and all the default values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, + ) -> None: + super(VocabParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + self._weight = None + # Divide the weight matrix along the vocaburaly dimension. 
+ self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size + ) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index + + # Allocate weights. + self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim)) + # And initialize. + _initialize_affine_weight( + self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method + ) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + """ + Performs an embedding lookup for input tokens in the parallelized embedding layer + 1. Masks tokens that fall outside the specified vocabulary range and adjusts the input + 2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero + 3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization + """ + # Build the mask for out-of-vocabulary tokens. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + # Get the embeddings for the valid tokens. + output_parallel = F.embedding( + masked_input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + # Embedding of out-of-vocabulary tokens is set to 0. + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs to get the final output. + output = reduce_from_model_parallel_region(output_parallel) + return output \ No newline at end of file diff --git a/picotron/tensor_parallel/mappings.py b/picotron/tensor_parallel/tp_communications.py similarity index 97% rename from picotron/tensor_parallel/mappings.py rename to picotron/tensor_parallel/tp_communications.py index 90cb356..1ac3dd2 100644 --- a/picotron/tensor_parallel/mappings.py +++ b/picotron/tensor_parallel/tp_communications.py @@ -2,7 +2,7 @@ Inspired by Fair Scale/Megatron's Tensor Parallelism implementation Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale """ -from picotron.tensor_parallel.utils import split_tensor_along_last_dim +from picotron.tensor_parallel.tp_utils import split_tensor_along_last_dim import torch.distributed as dist import torch import picotron.process_group_manager as pgm diff --git a/picotron/tensor_parallel/utils.py b/picotron/tensor_parallel/tp_utils.py similarity index 100% rename from picotron/tensor_parallel/utils.py rename to picotron/tensor_parallel/tp_utils.py diff --git a/train.py b/train.py index 7af0667..8842201 100644 --- a/train.py +++ b/train.py @@ -26,7 +26,7 @@ from picotron.utils import set_all_seed, print, to_readable_format, save_checkpo from picotron.data import MicroBatchDataLoader from picotron.process_group_manager import setup_process_group_manager from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel -from picotron.data_parallel.data_parallel_bucket import DataParallel +from picotron.data_parallel.data_parallel import DataParallelBucket from picotron.model import Llama import wandb @@ -203,7 +203,7 @@ if __name__ == "__main__": # Context parallel and Data parallel both need gradient synchronization if pgm.process_group_manager.cp_dp_world_size > 1: - model = DataParallel(model) + model = DataParallelBucket(model) print("init model 
parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank) start_time = time.time()