match TP loss

zzhhjjj 2024-10-27 02:22:05 +00:00
parent 51b5683dd3
commit e5cfb5240e
2 changed files with 91 additions and 353 deletions

View File

@@ -2,8 +2,9 @@
 Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
 Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
 """
-from src.parallel.tensor_parallel.utils import VocabUtility, divide_and_check_no_remainder
+from src.parallel.tensor_parallel.utils import VocabUtility
 import torch
+import math
 import torch.nn.init as init
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
@@ -11,229 +12,25 @@ from typing import Callable, Optional
 import src.distributed.process_group_manager as pgm
 from src.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
-# def _initialize_affine_weight(
-#     weight: torch.Tensor,
-#     out_features: int,
-#     in_features: int,
-#     per_partition_size: int,
-#     partition_dim: int,
-#     init_method: Callable[[torch.Tensor], torch.Tensor]
-# ) -> Optional[torch.Tensor]:
-#     """
-#     Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
-#     Args:
-#         weight: The weight tensor that will be initialized for the current partition.
-#         out_features: second dimension of weight matrix W.
-#         in_features: first dimension of weight matrix W.
-#         per_partition_size: The size of the weight partition assigned to each process.
-#         partition_dim: The dimension along which the weight matrix is split for parallelism.
-#         init_method: The method used to initialize the weight values.
-#     """
-
-#     # If we only use 1 process for model parallelism, we can simply initialize the weight
-#     if pgm.process_group_manager.tp_world_size == 1:
-#         init_method(weight)
-#         return None
-#     # Initialize master weight
-#     master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
-#     init_method(master_weight)
-#     # Split the model into size of per_partition_size and take the corresponding partition
-#     weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
-#     weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()
-#     return None
-# class ColumnParallelLinear(torch.nn.Module):
-#     """Column Parallel Linear layer
-#     Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
-#     This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension.
-#     Arguments:
-#         in_features: first dimension of weight matrix W.
-#         out_features: second dimension of weight matrix W.
-#         bias: If true, add bias
-#         init_method: method to initialize weights
-#         gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
-#     """
-#     def __init__(
-#         self,
-#         in_features: int,
-#         out_features: int,
-#         bias: bool = False,
-#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
-#         gather_output: bool = False,
-#     ) -> None:
-#         super(ColumnParallelLinear, self).__init__()
-#         self.in_features = in_features
-#         self.out_features = out_features
-#         assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
-#         self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
-#         self.gather_output = gather_output
-#         # Allocate space for the weight and bias
-#         # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
-#         self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
-#         if bias:
-#             self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
-#             # Always initialize bias to zero.
-#             with torch.no_grad():
-#                 self.bias.zero_()
-#         else:
-#             self.register_parameter("bias", None)
-#         # Initialize weight.
-#         _initialize_affine_weight(
-#             self.weight,
-#             self.out_features,
-#             self.in_features,
-#             self.output_size_per_partition,
-#             partition_dim = 0,
-#             init_method = init_method,
-#         )
-#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
-#         input_parallel = copy_to_model_parallel_region(input_)
-#         output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
-#         if self.gather_output:
-#             output = gather_from_model_parallel_region(output)
-#         return output
-# class RowParallelLinear(torch.nn.Module):
-#     """Linear layer with row parallelism.
-#     Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
-#                -   -
-#               | W_1 |
-#               |  .  |
-#           W = |  .  |        X = [X_1, ..., X_p]
-#               |  .  |
-#               | W_p |
-#                -   -
-#     We assume that X is already parallelized. This is the case after ColumnParallelLinear.
-#     This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method.
-#     Arguments:
-#         in_features: first dimension of matrix W.
-#         out_features: second dimension of matrix W.
-#         bias: If true, add bias
-#         init_method: method to initialize weights.
-#     """
-#     def __init__(
-#         self,
-#         in_features: int,
-#         out_features: int,
-#         bias: bool = True,
-#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
-#     ):
-#         super(RowParallelLinear, self).__init__()
-#         # Keep input parameters
-#         self.in_features = in_features
-#         self.out_features = out_features
-#         self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size
-#         self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
-#         if bias:
-#             self.bias = Parameter(torch.Tensor(self.out_features))
-#             # Always initialize bias to zero.
-#             with torch.no_grad():
-#                 self.bias.zero_()
-#         else:
-#             self.register_parameter("bias", None)
-#         # Initialize weight.
-#         _initialize_affine_weight(
-#             self.weight,
-#             self.out_features,
-#             self.in_features,
-#             self.input_size_per_partition,
-#             partition_dim = 1,
-#             init_method = init_method,
-#         )
-#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
-#         output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b
-#         # All-reduce across all the partitions.
-#         output_ = reduce_from_model_parallel_region(output_parallel)
-#         if self.bias is not None:
-#             output = output_ + self.bias
-#         else:
-#             output = output_
-#         return output
-# class VocabParallelEmbedding(torch.nn.Module):
-#     """Embedding parallelized in the vocabulary dimension.
-#     This is mainly adapted from torch.nn.Embedding and all the default values are kept.
-#     Arguments:
-#         num_embeddings: vocabulary size.
-#         embedding_dim: size of hidden state.
-#         init_method: method to initialize weights.
-#     """
-#     def __init__(
-#         self,
-#         num_embeddings: int,
-#         embedding_dim: int,
-#         padding_idx: Optional[int] = None,
-#         max_norm: Optional[float] = None,
-#         norm_type: float = 2.0,
-#         scale_grad_by_freq: bool = False,
-#         sparse: bool = False,
-#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
-#     ) -> None:
-#         super(VocabParallelEmbedding, self).__init__()
-#         # Keep the input dimensions.
-#         self.num_embeddings = num_embeddings
-#         self.embedding_dim = embedding_dim
-#         self.padding_idx = padding_idx
-#         self.max_norm = max_norm
-#         self.norm_type = norm_type
-#         self.scale_grad_by_freq = scale_grad_by_freq
-#         self.sparse = sparse
-#         self._weight = None
-#         # Divide the weight matrix along the vocaburaly dimension.
-#         self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
-#             self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
-#         )
-#         self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
-#         # Allocate weights.
-#         self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
-#         # And initialize.
-#         _initialize_affine_weight(
-#             self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
-#         )
-#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
-#         """
-#         Performs an embedding lookup for input tokens in the parallelized embedding layer
-#         1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
-#         2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
-#         3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
-#         """
-#         # Build the mask for out-of-vocabulary tokens.
-#         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
-#         # Mask the input.
-#         masked_input = input_.clone() - self.vocab_start_index
-#         masked_input[input_mask] = 0
-#         # Get the embeddings for the valid tokens.
-#         output_parallel = F.embedding(
-#             masked_input,
-#             self.weight,
-#             self.padding_idx,
-#             self.max_norm,
-#             self.norm_type,
-#             self.scale_grad_by_freq,
-#             self.sparse,
-#         )
-#         # Embedding of out-of-vocabulary tokens is set to 0.
-#         output_parallel[input_mask, :] = 0.0
-#         # Reduce across all the model parallel GPUs to get the final output.
-#         output = reduce_from_model_parallel_region(output_parallel)
-#         return output
+def initialize_weight_tensor(weight, vocab_embedding=False):
+    """
+    Initialize the weight tensor with the default initialization method in PyTorch
+    If not a vocab embedding, it uses U(-sqrt(k), sqrt(k)) with k = 1/in_features.
+    If it's a vocab embedding, it uses a normal distribution N(0, 1).
+    """
+    if not vocab_embedding:
+        # Get the in_features from the shape of the weight tensor
+        _, in_features = weight.shape
+
+        # Calculate k and the uniform bounds
+        k = 1 / in_features
+        bound = math.sqrt(k)
+
+        # Initialize weights with U(-sqrt(k), sqrt(k))
+        torch.nn.init.uniform_(weight, -bound, bound)
+    else:
+        # Initialize Vocab embedding with N(0, 1)
+        torch.nn.init.normal_(weight, mean=0.0, std=1.0)
 def _initialize_affine_weight(
     weight: torch.Tensor,
@@ -241,64 +38,44 @@ def _initialize_affine_weight(
     in_features: int,
     per_partition_size: int,
     partition_dim: int,
-    init_method: Callable[[torch.Tensor], torch.Tensor],
-    stride: int = 1,
-    return_master_weight: bool = False,
+    init_method: Callable[[torch.Tensor], torch.Tensor]
 ) -> Optional[torch.Tensor]:
-    """Initialize affine weight for model parallel.
-    Build the master weight on all processes and scatter
-    the relevant chunk."""
-    # If we only use 1 process for model parallelism, bypass scatter.
-    world_size = pgm.process_group_manager.world_size
-    if world_size == 1:
+    """
+    Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
+    Args:
+        weight: The weight tensor that will be initialized for the current partition.
+        out_features: second dimension of weight matrix W.
+        in_features: first dimension of weight matrix W.
+        per_partition_size: The size of the weight partition assigned to each process.
+        partition_dim: The dimension along which the weight matrix is split for parallelism.
+        init_method: The method used to initialize the weight values.
+    """
+    # If we only use 1 process for model parallelism, we can simply initialize the weight
+    if pgm.process_group_manager.tp_world_size == 1:
         init_method(weight)
-        if return_master_weight:
-            return weight
         return None
     # Initialize master weight
     master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
-    # init_method(master_weight)
-    k = 1.0 / in_features
-    bound = torch.sqrt(torch.tensor(k, dtype=master_weight.dtype))
-
-    # Use PyTorch's built-in uniform initialization
-    init.uniform_(master_weight, -bound.item(), bound.item())
-    # Split and copy
-    per_partition_per_stride_size = divide_and_check_no_remainder(per_partition_size, stride)
-    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
-    rank = pgm.process_group_manager.tp_rank
-    my_weight_list = weight_list[rank::world_size]
-    with torch.no_grad():
-        torch.cat(my_weight_list, dim=partition_dim, out=weight)
-    if return_master_weight:
-        return master_weight
+    init_method(master_weight)
+    # Split the model into size of per_partition_size and take the corresponding partition
+    weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
+    weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()
     return None
 class ColumnParallelLinear(torch.nn.Module):
-    """Linear layer with column parallelism.
-    The linear layer is defined as Y = XA + b. A is parallelized along
-    its second dimension as A = [A_1, ..., A_p].
+    """Column Parallel Linear layer
+    Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
+    This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension.
     Arguments:
-        in_features: first dimension of matrix A.
-        out_features: second dimension of matrix A.
+        in_features: first dimension of weight matrix W.
+        out_features: second dimension of weight matrix W.
         bias: If true, add bias
-        gather_output: If true, call all-gather on output and make Y avaiable
-                       to all GPUs, otherwise, every GPU will have its output
-                       which is Y_i = XA_i
-        init_method: method to initialize weights. Note that bias is always set
-                     to zero.
-        stride: For the strided linear layers.
-        keep_master_weight_for_test: This was added for testing and should be
-                                     set to False. It returns the master weights
-                                     used for initialization.
+        init_method: method to initialize weights
+        gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
     """
     def __init__(
@@ -306,25 +83,20 @@ class ColumnParallelLinear(torch.nn.Module):
         in_features: int,
         out_features: int,
         bias: bool = False,
-        gather_output: bool = True,
         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
-        stride: int = 1,
-        keep_master_weight_for_test: bool = False,
+        gather_output: bool = False,
     ) -> None:
         super(ColumnParallelLinear, self).__init__()
-        # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
+        assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
+        self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
         self.gather_output = gather_output
-        # Divide the weight matrix along the last dimension.
-        world_size = pgm.process_group_manager.tp_world_size
-        self.output_size_per_partition = divide_and_check_no_remainder(out_features, world_size)
-        # Parameters.
-        # Note: torch.nn.functional.linear performs XA^T + b and as a result
-        # we allocate the transpose.
-        self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))
+        # Allocate space for the weight and bias
+        # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
+        self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
         if bias:
             self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
             # Always initialize bias to zero.
@@ -334,58 +106,39 @@ class ColumnParallelLinear(torch.nn.Module):
             self.register_parameter("bias", None)
         # Initialize weight.
-        self.master_weight = _initialize_affine_weight(
+        _initialize_affine_weight(
             self.weight,
             self.out_features,
             self.in_features,
             self.output_size_per_partition,
-            0,
-            init_method,
-            stride=stride,
-            return_master_weight=keep_master_weight_for_test,
+            partition_dim = 0,
+            init_method = init_method,
         )
-    def get_master_weight(self) -> torch.Tensor:
-        return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1)
-    def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
-        # Backprop: all-reduce.
+    def forward(self, input_: torch.Tensor) -> torch.Tensor:
         input_parallel = copy_to_model_parallel_region(input_)
-        # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight, self.bias)
+        output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
         if self.gather_output:
-            # All-gather across the partitions.
-            output = gather_from_model_parallel_region(output_parallel)
-        else:
-            output = output_parallel
+            output = gather_from_model_parallel_region(output)
         return output
 class RowParallelLinear(torch.nn.Module):
     """Linear layer with row parallelism.
-    The linear layer is defined as Y = XA + b. A is parallelized along
-    its first dimension and X along its second dimension as:
+    Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
                -   -
-              | A_1 |
+              | W_1 |
               |  .  |
-          A = |  .  |        X = [X_1, ..., X_p]
+          W = |  .  |        X = [X_1, ..., X_p]
               |  .  |
-              | A_p |
+              | W_p |
                -   -
+    We assume that X is already parallelized. This is the case after ColumnParallelLinear.
+    This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method.
     Arguments:
-        in_features: first dimension of matrix A.
-        out_features: second dimension of matrix A.
-        bias: If true, add bias. Note that bias is not parallelized.
-        input_is_parallel: If true, we assume that the input is already
-                           split across the GPUs and we do not split
-                           again.
-        init_method: method to initialize weights. Note that bias is always set
-                     to zero.
-        stride: For the strided linear layers.
-        keep_master_weight_for_test: This was added for testing and should be
-                                     set to False. It returns the master weights
-                                     used for initialization.
+        in_features: first dimension of matrix W.
+        out_features: second dimension of matrix W.
+        bias: If true, add bias
+        init_method: method to initialize weights.
     """
     def __init__(
@@ -393,24 +146,15 @@ class RowParallelLinear(torch.nn.Module):
         in_features: int,
         out_features: int,
         bias: bool = True,
-        input_is_parallel: bool = True,
+        # Normally, input is parallelized, especially in Attention projection/MLP. There is a column parallel before it.
         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
-        stride: int = 1,
-        keep_master_weight_for_test: bool = False,
     ):
         super(RowParallelLinear, self).__init__()
         # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
-        self.input_is_parallel = input_is_parallel
-        # Divide the weight matrix along the last dimension.
-        world_size = pgm.process_group_manager.tp_world_size
-        self.input_size_per_partition = divide_and_check_no_remainder(in_features, world_size)
-        # Parameters.
-        # Note: torch.nn.functional.linear performs XA^T + b and as a result
-        # we allocate the transpose.
+        self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size
         self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
         if bias:
             self.bias = Parameter(torch.Tensor(self.out_features))
@@ -421,40 +165,28 @@ class RowParallelLinear(torch.nn.Module):
             self.register_parameter("bias", None)
         # Initialize weight.
-        self.master_weight = _initialize_affine_weight(
+        _initialize_affine_weight(
            self.weight,
            self.out_features,
            self.in_features,
            self.input_size_per_partition,
-           1,
-           init_method,
-           stride=stride,
-           return_master_weight=keep_master_weight_for_test,
+           partition_dim = 1,
+           init_method = init_method,
        )
-    def get_master_weight(self) -> torch.Tensor:
-        return gather_from_model_parallel_region(self.weight.data)
-    def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
-        # Set up backprop all-reduce.
-        input_parallel = input_
-        # Matrix multiply
-        output_parallel = F.linear(input_parallel, self.weight)
+    def forward(self, input_: torch.Tensor) -> torch.Tensor:
+        output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b
         # All-reduce across all the partitions.
         output_ = reduce_from_model_parallel_region(output_parallel)
         if self.bias is not None:
             output = output_ + self.bias
         else:
             output = output_
         return output
 class VocabParallelEmbedding(torch.nn.Module):
     """Embedding parallelized in the vocabulary dimension.
-    This is mainly adapted from torch.nn.Embedding and all the default
-    values are kept.
+    This is mainly adapted from torch.nn.Embedding and all the default values are kept.
     Arguments:
         num_embeddings: vocabulary size.
         embedding_dim: size of hidden state.
@@ -495,13 +227,19 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
         )
-    def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
-        # Build the mask.
+    def forward(self, input_: torch.Tensor) -> torch.Tensor:
+        """
+        Performs an embedding lookup for input tokens in the parallelized embedding layer
+        1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
+        2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
+        3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
+        """
+        # Build the mask for out-of-vocabulary tokens.
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
         # Mask the input.
         masked_input = input_.clone() - self.vocab_start_index
         masked_input[input_mask] = 0
-        # Get the embeddings.
+        # Get the embeddings for the valid tokens.
         output_parallel = F.embedding(
             masked_input,
             self.weight,
@@ -511,9 +249,8 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.scale_grad_by_freq,
             self.sparse,
         )
-        # Mask the output embedding.
-        # Embedding of tokens that are not in the vocabulary is set to 0. do a all reduce at the end.
+        # Embedding of out-of-vocabulary tokens is set to 0.
         output_parallel[input_mask, :] = 0.0
-        # Reduce across all the model parallel GPUs.
+        # Reduce across all the model parallel GPUs to get the final output.
         output = reduce_from_model_parallel_region(output_parallel)
         return output
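
Editor's note (a minimal sketch, not part of the commit): the new initialize_weight_tensor reproduces PyTorch's default nn.Linear weight initialization, where kaiming_uniform_ with a = sqrt(5) reduces to U(-sqrt(k), sqrt(k)) with k = 1/in_features; initializing the sharded weights with the same bound is presumably what lets the tensor-parallel run match the baseline loss. The layer sizes below are arbitrary illustration values.

# Sanity-check sketch: PyTorch's default nn.Linear init stays inside the
# U(-sqrt(1/in_features), sqrt(1/in_features)) range used by initialize_weight_tensor.
import math
import torch
import torch.nn as nn

in_features, out_features = 1024, 512   # assumed sizes, illustration only
bound = math.sqrt(1.0 / in_features)

torch.manual_seed(0)
reference = nn.Linear(in_features, out_features)   # default PyTorch initialization

assert reference.weight.min().item() >= -bound
assert reference.weight.max().item() <= bound
print(f"bound={bound:.5f}, observed range=({reference.weight.min():.5f}, {reference.weight.max():.5f})")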

View File

@@ -1,9 +1,10 @@
-from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
+from functools import partial
+from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
 import torch.nn.init as init
 import torch.nn as nn
 class TensorParallel():
-    def __init__(self, model, init_method = init.xavier_normal_):
+    def __init__(self, model, init_method = initialize_weight_tensor):
         super().__init__()
         module_linear_name_stype_mapping_list = [
@@ -46,6 +47,6 @@ class TensorParallel():
                 new_linear_layer = VocabParallelEmbedding(
                     num_embeddings=linear_layer.num_embeddings,
                     embedding_dim=linear_layer.embedding_dim,
-                    init_method=self.init_method
+                    init_method=partial(self.init_method, vocab_embedding=True)
                 )
                 setattr(module, linear_proj_name, new_linear_layer)
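
Editor's note (a hedged sketch with assumed tensor shapes, not from the repository): partial(self.init_method, vocab_embedding=True) pre-binds the keyword argument, so the single init_method callable gives VocabParallelEmbedding the N(0, 1) branch of initialize_weight_tensor while the column/row-parallel linear layers keep the uniform branch. The function is restated here only to keep the snippet self-contained.

from functools import partial
import math
import torch

def initialize_weight_tensor(weight, vocab_embedding=False):
    # Restated from the diff above for a runnable sketch.
    if not vocab_embedding:
        _, in_features = weight.shape
        bound = math.sqrt(1.0 / in_features)
        torch.nn.init.uniform_(weight, -bound, bound)
    else:
        torch.nn.init.normal_(weight, mean=0.0, std=1.0)

embedding_init = partial(initialize_weight_tensor, vocab_embedding=True)

linear_weight = torch.empty(512, 1024)       # (out_features, in_features), assumed sizes
embedding_weight = torch.empty(32000, 1024)  # (vocab_size, hidden_dim), assumed sizes

initialize_weight_tensor(linear_weight)      # U(-sqrt(1/1024), sqrt(1/1024))
embedding_init(embedding_weight)             # N(0, 1)
print(linear_weight.abs().max().item(), embedding_weight.std().item())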