match TP loss

zzhhjjj 2024-10-27 02:22:05 +00:00
parent 51b5683dd3
commit e5cfb5240e
2 changed files with 91 additions and 353 deletions


@@ -2,8 +2,9 @@
Inspired by FairScale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from src.parallel.tensor_parallel.utils import VocabUtility, divide_and_check_no_remainder
from src.parallel.tensor_parallel.utils import VocabUtility
import torch
import math
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
@@ -11,229 +12,25 @@ from typing import Callable, Optional
import src.distributed.process_group_manager as pgm
from src.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
# def _initialize_affine_weight(
# weight: torch.Tensor,
# out_features: int,
# in_features: int,
# per_partition_size: int,
# partition_dim: int,
# init_method: Callable[[torch.Tensor], torch.Tensor]
# ) -> Optional[torch.Tensor]:
# """
# Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
# Args:
# weight: The weight tensor that will be initialized for the current partition.
# out_features: second dimension of weight matrix W.
# in_features: first dimension of weight matrix W.
# per_partition_size: The size of the weight partition assigned to each process.
# partition_dim: The dimension along which the weight matrix is split for parallelism.
# init_method: The method used to initialize the weight values.
# """
# # If we only use 1 process for model parallelism, we can simply initialize the weight
# if pgm.process_group_manager.tp_world_size == 1:
# init_method(weight)
# return None
# # Initialize master weight
# master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
# init_method(master_weight)
# # Split the model into size of per_partition_size and take the corresponding partition
# weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
# weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()
# return None
# class ColumnParallelLinear(torch.nn.Module):
# """Column Parallel Linear layer
# Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
# This module returns the results of Y_i = XW_i + b_i in the forward method; Y_i is parallelized along the second dimension.
# Arguments:
# in_features: first dimension of weight matrix W.
# out_features: second dimension of weight matrix W.
# bias: If true, add bias
# init_method: method to initialize weights
# gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
# """
# def __init__(
# self,
# in_features: int,
# out_features: int,
# bias: bool = False,
# init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
# gather_output: bool = False,
# ) -> None:
# super(ColumnParallelLinear, self).__init__()
# self.in_features = in_features
# self.out_features = out_features
# assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
# self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
# self.gather_output = gather_output
# # Allocate space for the weight and bias
# # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
# self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
# if bias:
# self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
# # Always initialize bias to zero.
# with torch.no_grad():
# self.bias.zero_()
# else:
# self.register_parameter("bias", None)
# # Initialize weight.
# _initialize_affine_weight(
# self.weight,
# self.out_features,
# self.in_features,
# self.output_size_per_partition,
# partition_dim = 0,
# init_method = init_method,
# )
# def forward(self, input_: torch.Tensor) -> torch.Tensor:
# input_parallel = copy_to_model_parallel_region(input_)
# output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
# if self.gather_output:
# output = gather_from_model_parallel_region(output)
# return output
# class RowParallelLinear(torch.nn.Module):
# """Linear layer with row parallelism.
# Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
#              -   -
#             | W_1 |
#             |  .  |
#         W = |  .  |      X = [X_1, ..., X_p]
#             |  .  |
#             | W_p |
#              -   -
# We assume that X is already parallelized. This is the case after ColumnParallelLinear.
# This module returns Y = sum_i(X_i W_i) + b in the forward method (the bias is added once, after the reduction).
# Arguments:
# in_features: first dimension of matrix W.
# out_features: second dimension of matrix W.
# bias: If true, add bias
# init_method: method to initialize weights.
# """
# def __init__(
# self,
# in_features: int,
# out_features: int,
# bias: bool = True,
# init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
# ):
# super(RowParallelLinear, self).__init__()
# # Keep input parameters
# self.in_features = in_features
# self.out_features = out_features
# self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size
# self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
# if bias:
# self.bias = Parameter(torch.Tensor(self.out_features))
# # Always initialize bias to zero.
# with torch.no_grad():
# self.bias.zero_()
# else:
# self.register_parameter("bias", None)
# # Initialize weight.
# _initialize_affine_weight(
# self.weight,
# self.out_features,
# self.in_features,
# self.input_size_per_partition,
# partition_dim = 1,
# init_method = init_method,
# )
# def forward(self, input_: torch.Tensor) -> torch.Tensor:
# output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b
# # All-reduce across all the partitions.
# output_ = reduce_from_model_parallel_region(output_parallel)
# if self.bias is not None:
# output = output_ + self.bias
# else:
# output = output_
# return output
# class VocabParallelEmbedding(torch.nn.Module):
# """Embedding parallelized in the vocabulary dimension.
# This is mainly adapted from torch.nn.Embedding and all the default values are kept.
# Arguments:
# num_embeddings: vocabulary size.
# embedding_dim: size of hidden state.
# init_method: method to initialize weights.
# """
# def __init__(
# self,
# num_embeddings: int,
# embedding_dim: int,
# padding_idx: Optional[int] = None,
# max_norm: Optional[float] = None,
# norm_type: float = 2.0,
# scale_grad_by_freq: bool = False,
# sparse: bool = False,
# init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
# ) -> None:
# super(VocabParallelEmbedding, self).__init__()
# # Keep the input dimensions.
# self.num_embeddings = num_embeddings
# self.embedding_dim = embedding_dim
# self.padding_idx = padding_idx
# self.max_norm = max_norm
# self.norm_type = norm_type
# self.scale_grad_by_freq = scale_grad_by_freq
# self.sparse = sparse
# self._weight = None
# # Divide the weight matrix along the vocabulary dimension.
# self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
# self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
# )
# self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
# # Allocate weights.
# self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
# # And initialize.
# _initialize_affine_weight(
# self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
# )
# def forward(self, input_: torch.Tensor) -> torch.Tensor:
# """
# Performs an embedding lookup for input tokens in the parallelized embedding layer
# 1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
# 2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
# 3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
# """
# # Build the mask for out-of-vocabulary tokens.
# input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# # Mask the input.
# masked_input = input_.clone() - self.vocab_start_index
# masked_input[input_mask] = 0
# # Get the embeddings for the valid tokens.
# output_parallel = F.embedding(
# masked_input,
# self.weight,
# self.padding_idx,
# self.max_norm,
# self.norm_type,
# self.scale_grad_by_freq,
# self.sparse,
# )
# # Embedding of out-of-vocabulary tokens is set to 0.
# output_parallel[input_mask, :] = 0.0
# # Reduce across all the model parallel GPUs to get the final output.
# output = reduce_from_model_parallel_region(output_parallel)
# return output
def initialize_weight_tensor(weight, vocab_embedding=False):
"""
Initialize the weight tensor with the default initialization method in PyTorch
If not a vocab embedding, it uses U(-sqrt(k), sqrt(k)) with k = 1/in_features.
If it's a vocab embedding, it uses a normal distribution N(0, 1).
"""
if not vocab_embedding:
# Get the in_features from the shape of the weight tensor
_, in_features = weight.shape
# Calculate k and the uniform bounds
k = 1 / in_features
bound = math.sqrt(k)
# Initialize weights with U(-sqrt(k), sqrt(k))
torch.nn.init.uniform_(weight, -bound, bound)
else:
# Initialize Vocab embedding with N(0, 1)
torch.nn.init.normal_(weight, mean=0.0, std=1.0)
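# Editor's aside (illustrative sketch, not part of this commit): a minimal
# single-process check that the U(-sqrt(k), sqrt(k)) scheme above matches
# torch.nn.Linear's default reset_parameters, whose kaiming_uniform_(a=sqrt(5))
# reduces to the same +/- sqrt(1/in_features) bound. Reuses the math/torch
# imports at the top of this file.
def _check_matches_torch_linear_default():
    layer = torch.nn.Linear(16, 8)               # in_features = 16
    bound = math.sqrt(1.0 / 16)                  # k = 1 / in_features
    assert layer.weight.abs().max().item() <= bound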
def _initialize_affine_weight(
weight: torch.Tensor,
@@ -241,64 +38,44 @@ def _initialize_affine_weight(
in_features: int,
per_partition_size: int,
partition_dim: int,
init_method: Callable[[torch.Tensor], torch.Tensor],
stride: int = 1,
return_master_weight: bool = False,
init_method: Callable[[torch.Tensor], torch.Tensor]
) -> Optional[torch.Tensor]:
"""Initialize affine weight for model parallel.
"""
Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
Args:
weight: The weight tensor that will be initialized for the current partition.
out_features: second dimension of weight matrix W.
in_features: first dimension of weight matrix W.
per_partition_size: The size of the weight partition assigned to each process.
partition_dim: The dimension along which the weight matrix is split for parallelism.
init_method: The method used to initialize the weight values.
"""
Build the master weight on all processes and scatter
the relevant chunk."""
# If we only use 1 process for model parallelism, bypass scatter.
world_size = pgm.process_group_manager.world_size
if world_size == 1:
# If we only use 1 process for model parallelism, we can simply initialize the weight
if pgm.process_group_manager.tp_world_size == 1:
init_method(weight)
if return_master_weight:
return weight
return None
# Initialize master weight
master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
# init_method(master_weight)
init_method(master_weight)
k = 1.0 / in_features
bound = torch.sqrt(torch.tensor(k, dtype=master_weight.dtype))
# Split the model into size of per_partition_size and take the corresponding partition
weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()
# Use PyTorch's built-in uniform initialization
init.uniform_(master_weight, -bound.item(), bound.item())
# Split and copy
per_partition_per_stride_size = divide_and_check_no_remainder(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
rank = pgm.process_group_manager.tp_rank
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
if return_master_weight:
return master_weight
return None
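# Editor's aside (illustrative sketch; single process, tp_world_size = 2
# assumed): the split above hands each rank one contiguous slice of the master
# weight along partition_dim. Reuses the torch import at the top of this file.
def _check_partition_slices():
    out_features, in_features, tp_world_size = 6, 4, 2
    per_partition_size = out_features // tp_world_size    # 3 rows per rank
    master_weight = torch.randn(out_features, in_features)
    weight_list = torch.split(master_weight, per_partition_size, dim=0)
    for tp_rank, shard in enumerate(weight_list):
        rows = slice(tp_rank * per_partition_size, (tp_rank + 1) * per_partition_size)
        assert torch.equal(shard, master_weight[rows])    # rank 0: rows 0-2, rank 1: rows 3-5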
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
"""Column Parallel Linear layer
Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
This module returns the results of Y_i = XW_i + b_i in the forward method; Y_i is parallelized along the second dimension.
Arguments:
in_features: first dimension of matrix A.
out_features: second dimension of matrix A.
in_features: first dimension of weight matrix W.
out_features: second dimension of weight matrix W.
bias: If true, add bias
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
init_method: method to initialize weights
gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
"""
def __init__(
@@ -306,25 +83,20 @@ class ColumnParallelLinear(torch.nn.Module):
in_features: int,
out_features: int,
bias: bool = False,
gather_output: bool = True,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
gather_output: bool = False,
) -> None:
super(ColumnParallelLinear, self).__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = pgm.process_group_manager.tp_world_size
self.output_size_per_partition = divide_and_check_no_remainder(out_features, world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))
# Allocate space for the weight and bias
# Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
if bias:
self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
# Always initialize bias to zero.
@@ -334,58 +106,39 @@ class ColumnParallelLinear(torch.nn.Module):
self.register_parameter("bias", None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
_initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.output_size_per_partition,
0,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
partition_dim = 0,
init_method = init_method,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
# Backprop: all-reduce.
def forward(self, input_: torch.Tensor) -> torch.Tensor:
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, self.bias)
output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
output = gather_from_model_parallel_region(output)
return output
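# Editor's aside (illustrative sketch; the TP ranks are simulated in a loop,
# no process groups): concatenating the per-rank Y_i = X @ W_i^T along the last
# dimension reproduces the full linear output, which is what
# gather_from_model_parallel_region achieves in the distributed case.
def _check_column_parallel_math():
    tp_world_size, d_in, d_out = 2, 8, 6
    x = torch.randn(5, d_in)
    w = torch.randn(d_out, d_in)                              # full (out_features, in_features)
    shards = torch.split(w, d_out // tp_world_size, dim=0)    # W_i: one output slice per rank
    y = torch.cat([F.linear(x, w_i) for w_i in shards], dim=-1)  # simulated all-gather
    assert torch.allclose(y, F.linear(x, w), atol=1e-5)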
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
             -   -
            | W_1 |
            |  .  |
        W = |  .  |      X = [X_1, ..., X_p]
            |  .  |
            | W_p |
             -   -
We assume that X is already parallelized. This is the case after ColumnParallelLinear.
This module returns Y = sum_i(X_i W_i) + b in the forward method (the bias is added once, after the reduction).
Arguments:
in_features: first dimension of matrix A.
out_features: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
in_features: first dimension of matrix W.
out_features: second dimension of matrix W.
bias: If true, add bias
init_method: method to initialize weights.
"""
def __init__(
@@ -393,24 +146,15 @@ class RowParallelLinear(torch.nn.Module):
in_features: int,
out_features: int,
bias: bool = True,
input_is_parallel: bool = True, # The input is normally already parallelized (e.g., in the attention output projection and the MLP down-projection, a column-parallel layer precedes this one).
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
):
super(RowParallelLinear, self).__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = pgm.process_group_manager.tp_world_size
self.input_size_per_partition = divide_and_check_no_remainder(in_features, world_size)
self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
if bias:
self.bias = Parameter(torch.Tensor(self.out_features))
@@ -421,40 +165,28 @@ class RowParallelLinear(torch.nn.Module):
self.register_parameter("bias", None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
_initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.input_size_per_partition,
1,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
partition_dim = 1,
init_method = init_method,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
# Set up backprop all-reduce.
input_parallel = input_
# Matrix multiply
output_parallel = F.linear(input_parallel, self.weight)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel)
if self.bias is not None:
output = output_ + self.bias
else:
output = output_
return output
return output
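# Editor's aside (illustrative sketch, same simulated setup): with X split along
# its last dimension and W along its input dimension, summing the partial
# products X_i @ W_i^T equals the full matmul; the all-reduce performs exactly
# this sum, and the bias is added once afterwards.
def _check_row_parallel_math():
    tp_world_size, d_in, d_out = 2, 8, 6
    x = torch.randn(5, d_in)
    w = torch.randn(d_out, d_in)
    x_parts = torch.split(x, d_in // tp_world_size, dim=-1)   # X_i
    w_parts = torch.split(w, d_in // tp_world_size, dim=1)    # W_i in the (out, in) layout
    y = sum(F.linear(x_i, w_i) for x_i, w_i in zip(x_parts, w_parts))  # simulated all-reduce
    assert torch.allclose(y, F.linear(x, w), atol=1e-5)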
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
This is mainly adapted from torch.nn.Embedding and all the default values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
@@ -495,13 +227,19 @@ class VocabParallelEmbedding(torch.nn.Module):
self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
# Build the mask.
def forward(self, input_: torch.Tensor) -> torch.Tensor:
"""
Performs an embedding lookup for input tokens in the parallelized embedding layer
1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
"""
# Build the mask for out-of-vocabulary tokens.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
# Get the embeddings.
# Get the embeddings for the valid tokens.
output_parallel = F.embedding(
masked_input,
self.weight,
@@ -511,9 +249,8 @@ class VocabParallelEmbedding(torch.nn.Module):
self.scale_grad_by_freq,
self.sparse,
)
# Mask the output embedding.
# Embedding of tokens that are not in this partition's vocabulary range is set to 0; an all-reduce follows at the end.
# Embedding of out-of-vocabulary tokens is set to 0.
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
# Reduce across all the model parallel GPUs to get the final output.
output = reduce_from_model_parallel_region(output_parallel)
return output
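For intuition, a minimal single-process sketch of the masked lookup plus all-reduce above; the two TP ranks are simulated with a loop, and the shapes are illustrative assumptions, not values from the commit:

import torch
import torch.nn.functional as F

vocab_size, hidden, tp_world_size = 10, 4, 2
full_weight = torch.randn(vocab_size, hidden)
tokens = torch.tensor([1, 7, 3, 9])
per_rank = vocab_size // tp_world_size

outputs = []
for tp_rank in range(tp_world_size):
    start, end = tp_rank * per_rank, (tp_rank + 1) * per_rank
    mask = (tokens < start) | (tokens >= end)   # tokens outside this rank's vocab range
    local = tokens - start                      # shift into the shard's index space
    local[mask] = 0
    emb = F.embedding(local, full_weight[start:end])
    emb[mask, :] = 0.0                          # zero the out-of-range lookups
    outputs.append(emb)
result = torch.stack(outputs).sum(dim=0)        # the all-reduce step
assert torch.allclose(result, F.embedding(tokens, full_weight))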


@@ -1,9 +1,10 @@
from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from functools import partial
from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
import torch.nn.init as init
import torch.nn as nn
class TensorParallel():
def __init__(self, model, init_method = init.xavier_normal_):
def __init__(self, model, init_method = initialize_weight_tensor):
super().__init__()
module_linear_name_stype_mapping_list = [
@@ -46,6 +47,6 @@ class TensorParallel():
new_linear_layer = VocabParallelEmbedding(
num_embeddings=linear_layer.num_embeddings,
embedding_dim=linear_layer.embedding_dim,
init_method=self.init_method
init_method=partial(self.init_method, vocab_embedding=True)
)
setattr(module, linear_proj_name, new_linear_layer)
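As a usage note, a standalone sketch (toy shapes, hypothetical) of what the functools.partial change does: it pins vocab_embedding=True so embedding weights take the N(0, 1) path of initialize_weight_tensor, while the linear layers keep the uniform default:

import math
from functools import partial
import torch

def initialize_weight_tensor(weight, vocab_embedding=False):
    # same scheme as in layers.py above
    if vocab_embedding:
        torch.nn.init.normal_(weight, mean=0.0, std=1.0)
    else:
        bound = math.sqrt(1.0 / weight.shape[1])
        torch.nn.init.uniform_(weight, -bound, bound)

init_embedding = partial(initialize_weight_tensor, vocab_embedding=True)
initialize_weight_tensor(torch.empty(4, 16))  # U(-0.25, 0.25): in_features = 16
init_embedding(torch.empty(100, 16))          # N(0, 1)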