From e5cfb5240eb10f57a65e0460c0a8d5d66f0e16a8 Mon Sep 17 00:00:00 2001
From: zzhhjjj
Date: Sun, 27 Oct 2024 02:22:05 +0000
Subject: [PATCH] match TP loss

---
 src/parallel/tensor_parallel/layers.py          | 437 ++++----------------
 .../tensor_parallel/tensor_parallel.py          |   7 +-
 2 files changed, 91 insertions(+), 353 deletions(-)

diff --git a/src/parallel/tensor_parallel/layers.py b/src/parallel/tensor_parallel/layers.py
index a892b45..148f60a 100644
--- a/src/parallel/tensor_parallel/layers.py
+++ b/src/parallel/tensor_parallel/layers.py
@@ -2,8 +2,9 @@
 Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
 Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
 """
-from src.parallel.tensor_parallel.utils import VocabUtility, divide_and_check_no_remainder
+from src.parallel.tensor_parallel.utils import VocabUtility
 import torch
+import math
 import torch.nn.init as init
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
@@ -11,229 +12,25 @@
 from typing import Callable, Optional
 import src.distributed.process_group_manager as pgm
 from src.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
 
-# def _initialize_affine_weight(
-#     weight: torch.Tensor,
-#     out_features: int,
-#     in_features: int,
-#     per_partition_size: int,
-#     partition_dim: int,
-#     init_method: Callable[[torch.Tensor], torch.Tensor]
-# ) -> Optional[torch.Tensor]:
-#     """
-#     Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
-#     Args:
-#         weight: The weight tensor that will be initialized for the current partition.
-#         out_features: second dimension of weight matrix W.
-#         in_features: first dimension of weight matrix W.
-#         per_partition_size: The size of the weight partition assigned to each process.
-#         partition_dim: The dimension along which the weight matrix is split for parallelism.
-#         init_method: The method used to initialize the weight values.
-#     """
-
-#     # If we only use 1 process for model parallelism, we can simply initialize the weight
-#     if pgm.process_group_manager.tp_world_size == 1:
-#         init_method(weight)
-#         return None
-
-#     # Initialize master weight
-#     master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
-#     init_method(master_weight)
-
-#     # Split the model into size of per_partition_size and take the corresponding partition
-#     weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
-#     weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()
-
-#     return None
-
-# class ColumnParallelLinear(torch.nn.Module):
-#     """Column Parallel Linear layer
-#     Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
-#     This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension.
-#     Arguments:
-#         in_features: first dimension of weight matrix W.
-#         out_features: second dimension of weight matrix W.
-#         bias: If true, add bias
-#         init_method: method to initialize weights
-#         gather_output: If true, gather the output from all the partitions.
-#                        This is used for the last linear layer
-#     """
-
-#     def __init__(
-#         self,
-#         in_features: int,
-#         out_features: int,
-#         bias: bool = False,
-#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
-#         gather_output: bool = False,
-#     ) -> None:
-#         super(ColumnParallelLinear, self).__init__()
-
-#         self.in_features = in_features
-#         self.out_features = out_features
-#         assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
-#         self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
-#         self.gather_output = gather_output
-
-#         # Allocate space for the weight and bias
-#         # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
-#         self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
-#         if bias:
-#             self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
-#             # Always initialize bias to zero.
-#             with torch.no_grad():
-#                 self.bias.zero_()
-#         else:
-#             self.register_parameter("bias", None)
-
-#         # Initialize weight.
-#         _initialize_affine_weight(
-#             self.weight,
-#             self.out_features,
-#             self.in_features,
-#             self.output_size_per_partition,
-#             partition_dim = 0,
-#             init_method = init_method,
-#         )
-
-#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
-#         input_parallel = copy_to_model_parallel_region(input_)
-#         output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
-#         if self.gather_output:
-#             output = gather_from_model_parallel_region(output)
-#         return output
-
-# class RowParallelLinear(torch.nn.Module):
-#     """Linear layer with row parallelism.
-#     Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
-#                -   -
-#                | W_1 |
-#                | .   |
-#            W = | .   |        X = [X_1, ..., X_p]
-#                | .   |
-#                | W_p |
-#                -   -
-#     We assume that X is already parallelized. This is the case after ColumnParallelLinear.
-#     This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method.
-#     Arguments:
-#         in_features: first dimension of matrix W.
-#         out_features: second dimension of matrix W.
-#         bias: If true, add bias
-#         init_method: method to initialize weights.
-#     """
-
-#     def __init__(
-#         self,
-#         in_features: int,
-#         out_features: int,
-#         bias: bool = True,
-#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
-#     ):
-#         super(RowParallelLinear, self).__init__()
-
-#         # Keep input parameters
-#         self.in_features = in_features
-#         self.out_features = out_features
-#         self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size
-
-#         self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
-#         if bias:
-#             self.bias = Parameter(torch.Tensor(self.out_features))
-#             # Always initialize bias to zero.
-#             with torch.no_grad():
-#                 self.bias.zero_()
-#         else:
-#             self.register_parameter("bias", None)
-
-#         # Initialize weight.
-#         _initialize_affine_weight(
-#             self.weight,
-#             self.out_features,
-#             self.in_features,
-#             self.input_size_per_partition,
-#             partition_dim = 1,
-#             init_method = init_method,
-#         )
-
-#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
-#         output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b
-#         # All-reduce across all the partitions.
-#         output_ = reduce_from_model_parallel_region(output_parallel)
-#         if self.bias is not None:
-#             output = output_ + self.bias
-#         else:
-#             output = output_
-#         return output
-
-# class VocabParallelEmbedding(torch.nn.Module):
-#     """Embedding parallelized in the vocabulary dimension.
-#     This is mainly adapted from torch.nn.Embedding and all the default values are kept.
-#     Arguments:
-#         num_embeddings: vocabulary size.
-#         embedding_dim: size of hidden state.
-#         init_method: method to initialize weights.
-#     """
-
-#     def __init__(
-#         self,
-#         num_embeddings: int,
-#         embedding_dim: int,
-#         padding_idx: Optional[int] = None,
-#         max_norm: Optional[float] = None,
-#         norm_type: float = 2.0,
-#         scale_grad_by_freq: bool = False,
-#         sparse: bool = False,
-#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
-#     ) -> None:
-#         super(VocabParallelEmbedding, self).__init__()
-#         # Keep the input dimensions.
-#         self.num_embeddings = num_embeddings
-#         self.embedding_dim = embedding_dim
-#         self.padding_idx = padding_idx
-#         self.max_norm = max_norm
-#         self.norm_type = norm_type
-#         self.scale_grad_by_freq = scale_grad_by_freq
-#         self.sparse = sparse
-#         self._weight = None
-#         # Divide the weight matrix along the vocaburaly dimension.
-#         self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
-#             self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
-#         )
-#         self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
-
-#         # Allocate weights.
-#         self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
-#         # And initialize.
-#         _initialize_affine_weight(
-#             self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
-#         )
-
-#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
-#         """
-#         Performs an embedding lookup for input tokens in the parallelized embedding layer
-#         1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
-#         2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
-#         3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
-#         """
-#         # Build the mask for out-of-vocabulary tokens.
-#         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
-#         # Mask the input.
-#         masked_input = input_.clone() - self.vocab_start_index
-#         masked_input[input_mask] = 0
-#         # Get the embeddings for the valid tokens.
-#         output_parallel = F.embedding(
-#             masked_input,
-#             self.weight,
-#             self.padding_idx,
-#             self.max_norm,
-#             self.norm_type,
-#             self.scale_grad_by_freq,
-#             self.sparse,
-#         )
-#         # Embedding of out-of-vocabulary tokens is set to 0.
-#         output_parallel[input_mask, :] = 0.0
-#         # Reduce across all the model parallel GPUs to get the final output.
-#         output = reduce_from_model_parallel_region(output_parallel)
-#         return output
-
+def initialize_weight_tensor(weight, vocab_embedding=False):
+    """
+    Initialize the weight tensor with the default initialization method used by PyTorch.
+    If it is not a vocab embedding, use U(-sqrt(k), sqrt(k)) with k = 1/in_features (torch.nn.Linear's default).
+    If it is a vocab embedding, use a normal distribution N(0, 1) (torch.nn.Embedding's default).
+    """
+    if not vocab_embedding:
+        # Get in_features from the shape of the weight tensor
+        _, in_features = weight.shape
+
+        # Compute k and the uniform bound
+        k = 1 / in_features
+        bound = math.sqrt(k)
+
+        # Initialize weights with U(-sqrt(k), sqrt(k))
+        torch.nn.init.uniform_(weight, -bound, bound)
+    else:
+        # Initialize the vocab embedding with N(0, 1)
+        torch.nn.init.normal_(weight, mean=0.0, std=1.0)
 
 def _initialize_affine_weight(
     weight: torch.Tensor,
@@ -241,64 +38,44 @@ def _initialize_affine_weight(
     in_features: int,
     per_partition_size: int,
     partition_dim: int,
-    init_method: Callable[[torch.Tensor], torch.Tensor],
-    stride: int = 1,
-    return_master_weight: bool = False,
+    init_method: Callable[[torch.Tensor], torch.Tensor]
 ) -> Optional[torch.Tensor]:
-    """Initialize affine weight for model parallel.
+    """
+    Initialize the master weight for the entire linear layer. Each process then takes its own partition of the master weight.
+    Args:
+        weight: The weight tensor that will be initialized for the current partition.
+        out_features: second dimension of weight matrix W.
+        in_features: first dimension of weight matrix W.
+        per_partition_size: The size of the weight partition assigned to each process.
+        partition_dim: The dimension along which the weight matrix is split for parallelism.
+        init_method: The method used to initialize the weight values.
+    """
 
-    Build the master weight on all processes and scatter
-    the relevant chunk."""
-
-    # If we only use 1 process for model parallelism, bypass scatter.
-    world_size = pgm.process_group_manager.world_size
-    if world_size == 1:
+    # If we only use 1 process for model parallelism, we can simply initialize the weight
+    if pgm.process_group_manager.tp_world_size == 1:
         init_method(weight)
-        if return_master_weight:
-            return weight
         return None
 
     # Initialize master weight
     master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
-    # init_method(master_weight)
+    init_method(master_weight)
 
-    k = 1.0 / in_features
-    bound = torch.sqrt(torch.tensor(k, dtype=master_weight.dtype))
+    # Split the master weight into partitions of per_partition_size and keep the one for this rank
+    weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
+    weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()
 
-    # Use PyTorch's built-in uniform initialization
-    init.uniform_(master_weight, -bound.item(), bound.item())
-
-    # Split and copy
-    per_partition_per_stride_size = divide_and_check_no_remainder(per_partition_size, stride)
-    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
-    rank = pgm.process_group_manager.tp_rank
-    my_weight_list = weight_list[rank::world_size]
-
-    with torch.no_grad():
-        torch.cat(my_weight_list, dim=partition_dim, out=weight)
-    if return_master_weight:
-        return master_weight
     return None
 
 class ColumnParallelLinear(torch.nn.Module):
-    """Linear layer with column parallelism.
-
-    The linear layer is defined as Y = XA + b. A is parallelized along
-    its second dimension as A = [A_1, ..., A_p].
-
+    """Column Parallel Linear layer
+    Y = XW + b, where the weight matrix W is parallelized along its second dimension: W = [W_1, ..., W_p].
+    This module returns Y_i = XW_i + b_i in the forward method, where Y_i is the i-th partition of Y along the second dimension.
     Arguments:
-        in_features: first dimension of matrix A.
-        out_features: second dimension of matrix A.
+        in_features: first dimension of weight matrix W.
+        out_features: second dimension of weight matrix W.
         bias: If true, add bias
-        gather_output: If true, call all-gather on output and make Y avaiable
-                       to all GPUs, otherwise, every GPU will have its output
-                       which is Y_i = XA_i
-        init_method: method to initialize weights. Note that bias is always set
-                     to zero.
-        stride: For the strided linear layers.
-        keep_master_weight_for_test: This was added for testing and should be
-                                     set to False. It returns the master weights
-                                     used for initialization.
+        init_method: method to initialize weights.
+        gather_output: If true, gather the output from all the partitions. This is used for the last linear layer.
     """
 
     def __init__(
@@ -306,25 +83,20 @@ class ColumnParallelLinear(torch.nn.Module):
         in_features: int,
         out_features: int,
         bias: bool = False,
-        gather_output: bool = True,
         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
-        stride: int = 1,
-        keep_master_weight_for_test: bool = False,
+        gather_output: bool = False,
     ) -> None:
         super(ColumnParallelLinear, self).__init__()
 
-        # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
+        assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
+        self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
         self.gather_output = gather_output
-        # Divide the weight matrix along the last dimension.
-        world_size = pgm.process_group_manager.tp_world_size
-        self.output_size_per_partition = divide_and_check_no_remainder(out_features, world_size)
 
-        # Parameters.
-        # Note: torch.nn.functional.linear performs XA^T + b and as a result
-        # we allocate the transpose.
-        self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))
+        # Allocate space for the weight and bias
+        # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
+        self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
         if bias:
             self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
             # Always initialize bias to zero.
@@ -334,58 +106,39 @@ class ColumnParallelLinear(torch.nn.Module):
             self.register_parameter("bias", None)
 
         # Initialize weight.
-        self.master_weight = _initialize_affine_weight(
+        _initialize_affine_weight(
             self.weight,
             self.out_features,
             self.in_features,
             self.output_size_per_partition,
-            0,
-            init_method,
-            stride=stride,
-            return_master_weight=keep_master_weight_for_test,
+            partition_dim = 0,
+            init_method = init_method,
         )
 
-    def get_master_weight(self) -> torch.Tensor:
-        return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1)
-
-    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
-        # Backprop: all-reduce.
+    def forward(self, input_: torch.Tensor) -> torch.Tensor:
         input_parallel = copy_to_model_parallel_region(input_)
-        # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight, self.bias)
+        output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b_i, the output is Y_i
         if self.gather_output:
-            # All-gather across the partitions.
-            output = gather_from_model_parallel_region(output_parallel)
-        else:
-            output = output_parallel
+            output = gather_from_model_parallel_region(output)
         return output
-
-
+
 class RowParallelLinear(torch.nn.Module):
     """Linear layer with row parallelism.
-
-    The linear layer is defined as Y = XA + b. A is parallelized along
-    its first dimension and X along its second dimension as:
+    Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
                -   -
-               | A_1 |
+               | W_1 |
                | .   |
-           A = | .   |        X = [X_1, ..., X_p]
+           W = | .   |        X = [X_1, ..., X_p]
                | .   |
-               | A_p |
+               | W_p |
                -   -
+    We assume that X is already parallelized. This is the case after ColumnParallelLinear.
+    This module returns Y = sum(X_i * W_i) + b in the forward method; the bias is added once, after the all-reduce.
     Arguments:
-        in_features: first dimension of matrix A.
-        out_features: second dimension of matrix A.
-        bias: If true, add bias. Note that bias is not parallelized.
-        input_is_parallel: If true, we assume that the input is already
-                           split across the GPUs and we do not split
-                           again.
-        init_method: method to initialize weights. Note that bias is always set
-                     to zero.
-        stride: For the strided linear layers.
-        keep_master_weight_for_test: This was added for testing and should be
-                                     set to False. It returns the master weights
-                                     used for initialization.
+        in_features: first dimension of matrix W.
+        out_features: second dimension of matrix W.
+        bias: If true, add bias.
+        init_method: method to initialize weights.
     """
 
     def __init__(
@@ -393,24 +146,15 @@ class RowParallelLinear(torch.nn.Module):
         in_features: int,
         out_features: int,
         bias: bool = True,
-        input_is_parallel: bool = True, # Normally, input is parallelized, especially in Attention projection/MLP. There is a column parallel before it.
         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
-        stride: int = 1,
-        keep_master_weight_for_test: bool = False,
     ):
         super(RowParallelLinear, self).__init__()
 
         # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
-        self.input_is_parallel = input_is_parallel
-        # Divide the weight matrix along the last dimension.
-        world_size = pgm.process_group_manager.tp_world_size
-        self.input_size_per_partition = divide_and_check_no_remainder(in_features, world_size)
+        self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size
 
-        # Parameters.
-        # Note: torch.nn.functional.linear performs XA^T + b and as a result
-        # we allocate the transpose.
         self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
         if bias:
             self.bias = Parameter(torch.Tensor(self.out_features))
@@ -421,40 +165,28 @@ class RowParallelLinear(torch.nn.Module):
             self.register_parameter("bias", None)
 
         # Initialize weight.
-        self.master_weight = _initialize_affine_weight(
+        _initialize_affine_weight(
            self.weight,
            self.out_features,
            self.in_features,
            self.input_size_per_partition,
-            1,
-            init_method,
-            stride=stride,
-            return_master_weight=keep_master_weight_for_test,
+            partition_dim = 1,
+            init_method = init_method,
        )
 
-    def get_master_weight(self) -> torch.Tensor:
-        return gather_from_model_parallel_region(self.weight.data)
-
-    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type:ignore
-        # Set up backprop all-reduce.
-
-        input_parallel = input_
-
-        # Matrix multiply
-        output_parallel = F.linear(input_parallel, self.weight)
+    def forward(self, input_: torch.Tensor) -> torch.Tensor:
+        output_parallel = F.linear(input_, self.weight) # X_i * W_i^T, bias is added after the all-reduce
         # All-reduce across all the partitions.
         output_ = reduce_from_model_parallel_region(output_parallel)
        if self.bias is not None:
            output = output_ + self.bias
        else:
            output = output_
-        return output
+        return output
 
 class VocabParallelEmbedding(torch.nn.Module):
     """Embedding parallelized in the vocabulary dimension.
-
-    This is mainly adapted from torch.nn.Embedding and all the default
-    values are kept.
+    This is mainly adapted from torch.nn.Embedding and all the default values are kept.
     Arguments:
         num_embeddings: vocabulary size.
         embedding_dim: size of hidden state.
@@ -495,13 +227,19 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
     )
 
-    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
-        # Build the mask.
+    def forward(self, input_: torch.Tensor) -> torch.Tensor:
+        """
+        Performs an embedding lookup for input tokens in the parallelized embedding layer:
+        1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
+        2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
+        3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
+        """
+        # Build the mask for out-of-vocabulary tokens.
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
         # Mask the input.
         masked_input = input_.clone() - self.vocab_start_index
         masked_input[input_mask] = 0
-        # Get the embeddings.
+        # Get the embeddings for the valid tokens.
         output_parallel = F.embedding(
             masked_input,
             self.weight,
             self.padding_idx,
             self.max_norm,
             self.norm_type,
             self.scale_grad_by_freq,
             self.sparse,
         )
-        # Mask the output embedding.
-        # Embedding of tokens that are not in the vocabulary is set to 0. do a all reduce at the end.
+        # Embedding of out-of-vocabulary tokens is set to 0.
         output_parallel[input_mask, :] = 0.0
-        # Reduce across all the model parallel GPUs.
+        # Reduce across all the model parallel GPUs to get the final output.
         output = reduce_from_model_parallel_region(output_parallel)
         return output
\ No newline at end of file
diff --git a/src/parallel/tensor_parallel/tensor_parallel.py b/src/parallel/tensor_parallel/tensor_parallel.py
index 3e2c59b..d6712d2 100644
--- a/src/parallel/tensor_parallel/tensor_parallel.py
+++ b/src/parallel/tensor_parallel/tensor_parallel.py
@@ -1,9 +1,10 @@
-from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
+from functools import partial
+from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
 import torch.nn.init as init
 import torch.nn as nn
 
 class TensorParallel():
-    def __init__(self, model, init_method = init.xavier_normal_):
+    def __init__(self, model, init_method = initialize_weight_tensor):
         super().__init__()
 
         module_linear_name_stype_mapping_list = [
@@ -46,6 +47,6 @@ class TensorParallel():
             new_linear_layer = VocabParallelEmbedding(
                 num_embeddings=linear_layer.num_embeddings,
                 embedding_dim=linear_layer.embedding_dim,
-                init_method=self.init_method
+                init_method=partial(self.init_method, vocab_embedding=True)
             )
             setattr(module, linear_proj_name, new_linear_layer)
\ No newline at end of file
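
---

Usage note (illustrative sketch, not part of the patch): the point of `initialize_weight_tensor` is that the tensor-parallel layers now reproduce PyTorch's defaults -- nn.Linear draws weights from U(-sqrt(k), sqrt(k)) with k = 1/in_features, and nn.Embedding from N(0, 1) -- which is what lets the TP loss curve match the single-process baseline. A minimal sketch of the expected behavior (the tensor shapes below are made up for the example):

    import math
    import torch
    from src.parallel.tensor_parallel.layers import initialize_weight_tensor

    # Linear-style weight: bounded uniform, matching torch.nn.Linear's default init.
    w = torch.empty(8, 4)  # (out_features, in_features)
    initialize_weight_tensor(w)
    assert w.abs().max().item() <= math.sqrt(1 / 4)

    # Vocab embedding weight: standard normal, matching torch.nn.Embedding's default init.
    e = torch.empty(16, 8)  # (num_embeddings, embedding_dim)
    initialize_weight_tensor(e, vocab_embedding=True)

This mirrors the wiring in TensorParallel: linear layers receive `initialize_weight_tensor` as-is, while VocabParallelEmbedding receives `partial(self.init_method, vocab_embedding=True)`.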