match TP loss
parent 51b5683dd3
commit e5cfb5240e
@@ -2,8 +2,9 @@
Inspired by FairScale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from src.parallel.tensor_parallel.utils import VocabUtility, divide_and_check_no_remainder
from src.parallel.tensor_parallel.utils import VocabUtility
import torch
import math
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
@@ -11,229 +12,25 @@ from typing import Callable, Optional
import src.distributed.process_group_manager as pgm
from src.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region

# def _initialize_affine_weight(
#     weight: torch.Tensor,
#     out_features: int,
#     in_features: int,
#     per_partition_size: int,
#     partition_dim: int,
#     init_method: Callable[[torch.Tensor], torch.Tensor]
# ) -> Optional[torch.Tensor]:
#     """
#     Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
#     Args:
#         weight: The weight tensor that will be initialized for the current partition.
#         out_features: second dimension of weight matrix W.
#         in_features: first dimension of weight matrix W.
#         per_partition_size: The size of the weight partition assigned to each process.
#         partition_dim: The dimension along which the weight matrix is split for parallelism.
#         init_method: The method used to initialize the weight values.
#     """

#     # If we only use 1 process for model parallelism, we can simply initialize the weight
#     if pgm.process_group_manager.tp_world_size == 1:
#         init_method(weight)
#         return None

#     # Initialize master weight
#     master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
#     init_method(master_weight)

#     # Split the model into size of per_partition_size and take the corresponding partition
#     weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
#     weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()

#     return None

# class ColumnParallelLinear(torch.nn.Module):
#     """Column Parallel Linear layer
#     Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
#     This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension.
#     Arguments:
#         in_features: first dimension of weight matrix W.
#         out_features: second dimension of weight matrix W.
#         bias: If true, add bias
#         init_method: method to initialize weights
#         gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
#     """

#     def __init__(
#         self,
#         in_features: int,
#         out_features: int,
#         bias: bool = False,
#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
#         gather_output: bool = False,
#     ) -> None:
#         super(ColumnParallelLinear, self).__init__()

#         self.in_features = in_features
#         self.out_features = out_features
#         assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
#         self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
#         self.gather_output = gather_output

#         # Allocate space for the weight and bias
#         # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
#         self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
#         if bias:
#             self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
#             # Always initialize bias to zero.
#             with torch.no_grad():
#                 self.bias.zero_()
#         else:
#             self.register_parameter("bias", None)

#         # Initialize weight.
#         _initialize_affine_weight(
#             self.weight,
#             self.out_features,
#             self.in_features,
#             self.output_size_per_partition,
#             partition_dim = 0,
#             init_method = init_method,
#         )

#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
#         input_parallel = copy_to_model_parallel_region(input_)
#         output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
#         if self.gather_output:
#             output = gather_from_model_parallel_region(output)
#         return output

# class RowParallelLinear(torch.nn.Module):
#     """Linear layer with row parallelism.
#     Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
#                -   -
#               | W_1 |
#               | .   |
#           W = | .   |     X = [X_1, ..., X_p]
#               | .   |
#               | W_p |
#                -   -
#     We assume that X is already parallelized. This is the case after ColumnParallelLinear.
#     This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method.
#     Arguments:
#         in_features: first dimension of matrix W.
#         out_features: second dimension of matrix W.
#         bias: If true, add bias
#         init_method: method to initialize weights.
#     """

#     def __init__(
#         self,
#         in_features: int,
#         out_features: int,
#         bias: bool = True,
#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
#     ):
#         super(RowParallelLinear, self).__init__()

#         # Keep input parameters
#         self.in_features = in_features
#         self.out_features = out_features
#         self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size

#         self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
#         if bias:
#             self.bias = Parameter(torch.Tensor(self.out_features))
#             # Always initialize bias to zero.
#             with torch.no_grad():
#                 self.bias.zero_()
#         else:
#             self.register_parameter("bias", None)

#         # Initialize weight.
#         _initialize_affine_weight(
#             self.weight,
#             self.out_features,
#             self.in_features,
#             self.input_size_per_partition,
#             partition_dim = 1,
#             init_method = init_method,
#         )

#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
#         output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b
#         # All-reduce across all the partitions.
#         output_ = reduce_from_model_parallel_region(output_parallel)
#         if self.bias is not None:
#             output = output_ + self.bias
#         else:
#             output = output_
#         return output

# class VocabParallelEmbedding(torch.nn.Module):
#     """Embedding parallelized in the vocabulary dimension.
#     This is mainly adapted from torch.nn.Embedding and all the default values are kept.
#     Arguments:
#         num_embeddings: vocabulary size.
#         embedding_dim: size of hidden state.
#         init_method: method to initialize weights.
#     """

#     def __init__(
#         self,
#         num_embeddings: int,
#         embedding_dim: int,
#         padding_idx: Optional[int] = None,
#         max_norm: Optional[float] = None,
#         norm_type: float = 2.0,
#         scale_grad_by_freq: bool = False,
#         sparse: bool = False,
#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
#     ) -> None:
#         super(VocabParallelEmbedding, self).__init__()
#         # Keep the input dimensions.
#         self.num_embeddings = num_embeddings
#         self.embedding_dim = embedding_dim
#         self.padding_idx = padding_idx
#         self.max_norm = max_norm
#         self.norm_type = norm_type
#         self.scale_grad_by_freq = scale_grad_by_freq
#         self.sparse = sparse
#         self._weight = None
#         # Divide the weight matrix along the vocabulary dimension.
#         self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
#             self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
#         )
#         self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

#         # Allocate weights.
#         self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
#         # And initialize.
#         _initialize_affine_weight(
#             self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
#         )

#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
#         """
#         Performs an embedding lookup for input tokens in the parallelized embedding layer
#         1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
#         2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
#         3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
#         """
#         # Build the mask for out-of-vocabulary tokens.
#         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
#         # Mask the input.
#         masked_input = input_.clone() - self.vocab_start_index
#         masked_input[input_mask] = 0
#         # Get the embeddings for the valid tokens.
#         output_parallel = F.embedding(
#             masked_input,
#             self.weight,
#             self.padding_idx,
#             self.max_norm,
#             self.norm_type,
#             self.scale_grad_by_freq,
#             self.sparse,
#         )
#         # Embedding of out-of-vocabulary tokens is set to 0.
#         output_parallel[input_mask, :] = 0.0
#         # Reduce across all the model parallel GPUs to get the final output.
#         output = reduce_from_model_parallel_region(output_parallel)
#         return output

def initialize_weight_tensor(weight, vocab_embedding=False):
    """
    Initialize the weight tensor with the default initialization method in PyTorch
    If not a vocab embedding, it uses U(-sqrt(k), sqrt(k)) with k = 1/in_features.
    If it's a vocab embedding, it uses a normal distribution N(0, 1).
    """
    if not vocab_embedding:
        # Get the in_features from the shape of the weight tensor
        _, in_features = weight.shape

        # Calculate k and the uniform bounds
        k = 1 / in_features
        bound = math.sqrt(k)

        # Initialize weights with U(-sqrt(k), sqrt(k))
        torch.nn.init.uniform_(weight, -bound, bound)
    else:
        # Initialize Vocab embedding with N(0, 1)
        torch.nn.init.normal_(weight, mean=0.0, std=1.0)

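A minimal usage sketch of initialize_weight_tensor (assuming the function above is in scope): for a linear-style weight it reproduces PyTorch's default bound of sqrt(1/in_features), while a vocab embedding falls back to N(0, 1) as in torch.nn.Embedding.

import math
import torch

w = torch.empty(16, 64)                              # (out_features, in_features)
initialize_weight_tensor(w)                          # U(-sqrt(1/64), sqrt(1/64))
assert w.abs().max().item() <= math.sqrt(1 / 64)

emb = torch.empty(100, 32)                           # (num_embeddings, embedding_dim)
initialize_weight_tensor(emb, vocab_embedding=True)  # N(0, 1)
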
def _initialize_affine_weight(
    weight: torch.Tensor,
@@ -241,64 +38,44 @@ def _initialize_affine_weight(
    in_features: int,
    per_partition_size: int,
    partition_dim: int,
    init_method: Callable[[torch.Tensor], torch.Tensor],
    stride: int = 1,
    return_master_weight: bool = False,
    init_method: Callable[[torch.Tensor], torch.Tensor]
) -> Optional[torch.Tensor]:
    """Initialize affine weight for model parallel.
    """
    Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
    Args:
        weight: The weight tensor that will be initialized for the current partition.
        out_features: second dimension of weight matrix W.
        in_features: first dimension of weight matrix W.
        per_partition_size: The size of the weight partition assigned to each process.
        partition_dim: The dimension along which the weight matrix is split for parallelism.
        init_method: The method used to initialize the weight values.
    """

    Build the master weight on all processes and scatter
    the relevant chunk."""

    # If we only use 1 process for model parallelism, bypass scatter.
    world_size = pgm.process_group_manager.world_size
    if world_size == 1:
    # If we only use 1 process for model parallelism, we can simply initialize the weight
    if pgm.process_group_manager.tp_world_size == 1:
        init_method(weight)
        if return_master_weight:
            return weight
        return None

    # Initialize master weight
    master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
    # init_method(master_weight)
    init_method(master_weight)

    k = 1.0 / in_features
    bound = torch.sqrt(torch.tensor(k, dtype=master_weight.dtype))
    # Split the model into size of per_partition_size and take the corresponding partition
    weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
    weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()

    # Use PyTorch's built-in uniform initialization
    init.uniform_(master_weight, -bound.item(), bound.item())

    # Split and copy
    per_partition_per_stride_size = divide_and_check_no_remainder(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
    rank = pgm.process_group_manager.tp_rank
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None

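A toy, single-process check of the split-and-take-partition logic above (plain Python ints stand in for the pgm tensor-parallel rank and world size):

import torch

out_features, in_features, tp_world_size = 8, 4, 2
per_partition_size = out_features // tp_world_size

master_weight = torch.randn(out_features, in_features)
shards = torch.split(master_weight, per_partition_size, dim=0)   # partition_dim = 0

for tp_rank in range(tp_world_size):
    local_weight = shards[tp_rank].contiguous()                  # what each rank keeps
    assert local_weight.shape == (per_partition_size, in_features)

# Concatenating the shards along the partition dimension recovers the master weight.
assert torch.equal(torch.cat(shards, dim=0), master_weight)
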
class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    """Column Parallel Linear layer
    Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
    This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension.
    Arguments:
        in_features: first dimension of matrix A.
        out_features: second dimension of matrix A.
        in_features: first dimension of weight matrix W.
        out_features: second dimension of weight matrix W.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        init_method: method to initialize weights
        gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
    """

    def __init__(
@@ -306,25 +83,20 @@ class ColumnParallelLinear(torch.nn.Module):
        in_features: int,
        out_features: int,
        bias: bool = False,
        gather_output: bool = True,
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
        stride: int = 1,
        keep_master_weight_for_test: bool = False,
        gather_output: bool = False,
    ) -> None:
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
        self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = pgm.process_group_manager.tp_world_size
        self.output_size_per_partition = divide_and_check_no_remainder(out_features, world_size)

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))
        # Allocate space for the weight and bias
        # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
        self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
        if bias:
            self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
            # Always initialize bias to zero.
@@ -334,58 +106,39 @@ class ColumnParallelLinear(torch.nn.Module):
            self.register_parameter("bias", None)

        # Initialize weight.
        self.master_weight = _initialize_affine_weight(
        _initialize_affine_weight(
            self.weight,
            self.out_features,
            self.in_features,
            self.output_size_per_partition,
            0,
            init_method,
            stride=stride,
            return_master_weight=keep_master_weight_for_test,
            partition_dim = 0,
            init_method = init_method,
        )

    def get_master_weight(self) -> torch.Tensor:
        return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1)

    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
        # Backprop: all-reduce.
    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        input_parallel = copy_to_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel
            output = gather_from_model_parallel_region(output)
        return output

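A single-process sketch of the column-parallel math above (the communication ops drop out because copy_to_model_parallel_region behaves as a plain copy in the forward pass): each shard computes X @ W_i^T, and gathering the shards along the last dimension reproduces the full linear layer.

import torch
import torch.nn.functional as F

X = torch.randn(3, 8)                  # (batch, in_features)
W = torch.randn(6, 8)                  # full weight, (out_features, in_features)
W1, W2 = torch.chunk(W, 2, dim=0)      # column-parallel shards W_i

Y_full = F.linear(X, W)                             # X @ W.T
Y_shards = [F.linear(X, Wi) for Wi in (W1, W2)]     # each rank's X @ W_i.T

# gather_output=True corresponds to concatenating the shards along the last dim.
assert torch.allclose(torch.cat(Y_shards, dim=-1), Y_full, atol=1e-5)
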
class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
    Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | W_1 |
              | .   |
          A = | .   |     X = [X_1, ..., X_p]
          W = | .   |     X = [X_1, ..., X_p]
              | .   |
              | A_p |
              | W_p |
               -   -
    We assume that X is already parallelized. This is the case after ColumnParallelLinear.
    This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method.
    Arguments:
        in_features: first dimension of matrix A.
        out_features: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        in_features: first dimension of matrix W.
        out_features: second dimension of matrix W.
        bias: If true, add bias
        init_method: method to initialize weights.
    """

    def __init__(
@@ -393,24 +146,15 @@ class RowParallelLinear(torch.nn.Module):
        in_features: int,
        out_features: int,
        bias: bool = True,
        input_is_parallel: bool = True,  # Normally, input is parallelized, especially in Attention projection/MLP. There is a column parallel before it.
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
        stride: int = 1,
        keep_master_weight_for_test: bool = False,
    ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension.
        world_size = pgm.process_group_manager.tp_world_size
        self.input_size_per_partition = divide_and_check_no_remainder(in_features, world_size)
        self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
        if bias:
            self.bias = Parameter(torch.Tensor(self.out_features))
@@ -421,40 +165,28 @@ class RowParallelLinear(torch.nn.Module):
            self.register_parameter("bias", None)

        # Initialize weight.
        self.master_weight = _initialize_affine_weight(
        _initialize_affine_weight(
            self.weight,
            self.out_features,
            self.in_features,
            self.input_size_per_partition,
            1,
            init_method,
            stride=stride,
            return_master_weight=keep_master_weight_for_test,
            partition_dim = 1,
            init_method = init_method,
        )

    def get_master_weight(self) -> torch.Tensor:
        return gather_from_model_parallel_region(self.weight.data)

    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type:ignore
        # Set up backprop all-reduce.

        input_parallel = input_

        # Matrix multiply
        output_parallel = F.linear(input_parallel, self.weight)
    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        output_parallel = F.linear(input_, self.weight)  # X_i * W_i^T + b
        # All-reduce across all the partitions.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if self.bias is not None:
            output = output_ + self.bias
        else:
            output = output_
        return output
        return output

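A single-process sketch of the row-parallel identity the forward pass relies on: the all-reduce sums the partial products X_i @ W_i^T, which equals the full X @ W^T when X is split along in_features and W along its rows of the same dimension.

import torch
import torch.nn.functional as F

X = torch.randn(3, 8)                  # full activation, (batch, in_features)
W = torch.randn(6, 8)                  # full weight, (out_features, in_features)

X1, X2 = torch.chunk(X, 2, dim=-1)     # input already split, e.g. by a ColumnParallelLinear
W1, W2 = torch.chunk(W, 2, dim=1)      # row-parallel shards of W (split along in_features)

partial_sum = F.linear(X1, W1) + F.linear(X2, W2)   # what the all-reduce sums across ranks
assert torch.allclose(partial_sum, F.linear(X, W), atol=1e-5)
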
class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    This is mainly adapted from torch.nn.Embedding and all the default values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
@@ -495,13 +227,19 @@ class VocabParallelEmbedding(torch.nn.Module):
            self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
        )

    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
        # Build the mask.
    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """
        Performs an embedding lookup for input tokens in the parallelized embedding layer
        1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
        2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
        3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
        """
        # Build the mask for out-of-vocabulary tokens.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Mask the input.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
        # Get the embeddings.
        # Get the embeddings for the valid tokens.
        output_parallel = F.embedding(
            masked_input,
            self.weight,
@@ -511,9 +249,8 @@ class VocabParallelEmbedding(torch.nn.Module):
            self.scale_grad_by_freq,
            self.sparse,
        )
        # Mask the output embedding.
        # Embedding of tokens that are not in the vocabulary is set to 0; do an all-reduce at the end.
        # Embedding of out-of-vocabulary tokens is set to 0.
        output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        # Reduce across all the model parallel GPUs to get the final output.
        output = reduce_from_model_parallel_region(output_parallel)
        return output
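A toy sketch of the vocabulary masking in the forward pass above, for one rank that owns rows [4, 8) of an 8-token global vocabulary (i.e. tp_world_size = 2); after the all-reduce, exactly one rank contributes a non-zero embedding per token.

import torch
import torch.nn.functional as F

vocab_start_index, vocab_end_index = 4, 8
weight = torch.randn(vocab_end_index - vocab_start_index, 16)   # this rank's shard

input_ = torch.tensor([[1, 5, 7], [3, 4, 6]])
input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)

masked_input = input_.clone() - vocab_start_index
masked_input[input_mask] = 0                     # out-of-range tokens look up row 0

output_parallel = F.embedding(masked_input, weight)
output_parallel[input_mask, :] = 0.0             # then their embeddings are zeroed
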
@@ -1,9 +1,10 @@
from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from functools import partial
from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
import torch.nn.init as init
import torch.nn as nn

class TensorParallel():
    def __init__(self, model, init_method = init.xavier_normal_):
    def __init__(self, model, init_method = initialize_weight_tensor):
        super().__init__()

        module_linear_name_stype_mapping_list = [
@@ -46,6 +47,6 @@ class TensorParallel():
                new_linear_layer = VocabParallelEmbedding(
                    num_embeddings=linear_layer.num_embeddings,
                    embedding_dim=linear_layer.embedding_dim,
                    init_method=self.init_method
                    init_method=partial(self.init_method, vocab_embedding=True)
                )
                setattr(module, linear_proj_name, new_linear_layer)
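A small sketch of the functools.partial pattern used above (assuming initialize_weight_tensor is importable from src.parallel.tensor_parallel.layers, as in the new import line): binding vocab_embedding=True yields a one-argument init_method, which is the shape VocabParallelEmbedding expects.

from functools import partial
import torch
from src.parallel.tensor_parallel.layers import initialize_weight_tensor

embedding_init = partial(initialize_weight_tensor, vocab_embedding=True)
w = torch.empty(8, 4)
embedding_init(w)    # same as initialize_weight_tensor(w, vocab_embedding=True): N(0, 1)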