tensor parallel, will clean later
parent 54ad77e055 · commit 7377238741
.gitignore (vendored) | 5
@@ -1,3 +1,6 @@
__pycache__
*.pth
.vscode/
picotron.egg-info
*.ipynb
wandb
model.py | 15
@@ -5,6 +5,8 @@ import torch.nn.functional as F
import torch.nn.init as init
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.layers.rotary import apply_rotary_emb
import src.distributed.process_group_manager as pgm
from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16 if os.getenv('DATA_TYPE', 'bfloat16') == 'bfloat16' else torch.float32

@@ -56,13 +58,13 @@ class CausalSelfAttention(nn.Module):
        self.num_heads = config.num_attention_heads
        self.num_key_values = config.num_key_value_heads
        self.head_dim = self.hidden_size//self.num_heads
        # model_parallel_size = get_model_parallel_world_size()
        model_parallel_size = 1
        model_parallel_size = pgm.process_group_manager.tp_world_size
        self.num_local_heads = config.num_attention_heads // model_parallel_size # TP parallelism
        self.num_local_kv_heads = config.num_key_value_heads // model_parallel_size # TP parallelism
        self.is_merged_qkv_weight = os.getenv('MERGED_QKV_WEIGHT', '1')
        if self.is_merged_qkv_weight == '1':
            self.qkv_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim + 2*self.num_key_values*self.head_dim, bias=False)
            # self.qkv_proj = ColumnParallelLinear(config.hidden_size, self.num_heads*self.head_dim + 2*self.num_key_values*self.head_dim, bias=False, gather_output=False, init_method=init_method)
        else:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False)
            self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False)

@@ -134,11 +136,12 @@ class LLaMAMLP(nn.Module):
        self.merged_gate_up = os.getenv('MERGED_GATE_UP_WEIGHT', '1') == '1'
        if self.merged_gate_up:
            self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size*2, bias=False)
            # self.gate_up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size*2, bias=False, gather_output=False, init_method=init_method)
        else:
            self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
            self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
            # self.up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
            # self.gate_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
            # self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
            # self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
            self.up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
            self.gate_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        # self.down_proj = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=False, input_is_parallel=True, init_method=init_method)
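The attention hunk above is where the tensor-parallel sizing enters the model: each TP rank keeps only num_attention_heads // tp_world_size query heads and num_key_value_heads // tp_world_size KV heads. A minimal standalone sketch of that arithmetic (the helper below is illustrative and not part of the repo; the numbers come from the config overrides and the --tp_size 2 launch shown later in train.py):

def local_head_counts(num_attention_heads, num_key_value_heads, tp_world_size):
    # Mirrors the divisions done in CausalSelfAttention.__init__ above.
    assert num_attention_heads % tp_world_size == 0, "query heads must divide evenly across TP ranks"
    assert num_key_value_heads % tp_world_size == 0, "KV heads must divide evenly across TP ranks"
    return num_attention_heads // tp_world_size, num_key_value_heads // tp_world_size

print(local_head_counts(16, 4, 2))  # (8, 2): query heads and KV heads kept on each TP rank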
setup.py (new file) | 6
@@ -0,0 +1,6 @@
from setuptools import setup, find_packages
setup(
    name="picotron",  # Name of the package
    version='0.1.0',
    packages=find_packages(),  # Automatically find packages in the current directory
)
src/parallel/context_parallel.py
@@ -8,7 +8,7 @@ from src.distributed.distributed_primtives import ContextComms
from model import Attention
import src.distributed.process_group_manager as pgm

from parallel.base_parallel import BaseParallel
from src.parallel.base_parallel import BaseParallel

class ContextParallel(BaseParallel):
    def __init__(self, model, config):
src/parallel/pipeline_parallel.py
@@ -3,7 +3,7 @@ from src.distributed.distributed_primtives import pipeline_communicate, bidirect
import torch, torch.nn as nn, torch.nn.functional as F
import torch.distributed as dist

from parallel.base_parallel import BaseParallel
from src.parallel.base_parallel import BaseParallel

class PipelineParallel(BaseParallel):
    def __init__(self, model, config):
@@ -1 +0,0 @@
#TODO
src/parallel/tensor_parallel/__init__.py (new file) | 0
src/parallel/tensor_parallel/layers.py (new file) | 519
@@ -0,0 +1,519 @@
"""
Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from src.parallel.tensor_parallel.utils import VocabUtility, divide_and_check_no_remainder
import torch
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from typing import Callable, Optional
import src.distributed.process_group_manager as pgm
from src.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region

# def _initialize_affine_weight(
#     weight: torch.Tensor,
#     out_features: int,
#     in_features: int,
#     per_partition_size: int,
#     partition_dim: int,
#     init_method: Callable[[torch.Tensor], torch.Tensor]
# ) -> Optional[torch.Tensor]:
#     """
#     Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
#     Args:
#         weight: The weight tensor that will be initialized for the current partition.
#         out_features: second dimension of weight matrix W.
#         in_features: first dimension of weight matrix W.
#         per_partition_size: The size of the weight partition assigned to each process.
#         partition_dim: The dimension along which the weight matrix is split for parallelism.
#         init_method: The method used to initialize the weight values.
#     """
#
#     # If we only use 1 process for model parallelism, we can simply initialize the weight
#     if pgm.process_group_manager.tp_world_size == 1:
#         init_method(weight)
#         return None
#
#     # Initialize master weight
#     master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
#     init_method(master_weight)
#
#     # Split the model into size of per_partition_size and take the corresponding partition
#     weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
#     weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()
#
#     return None

# class ColumnParallelLinear(torch.nn.Module):
#     """Column Parallel Linear layer
#     Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
#     This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension.
#     Arguments:
#         in_features: first dimension of weight matrix W.
#         out_features: second dimension of weight matrix W.
#         bias: If true, add bias
#         init_method: method to initialize weights
#         gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
#     """
#
#     def __init__(
#         self,
#         in_features: int,
#         out_features: int,
#         bias: bool = False,
#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
#         gather_output: bool = False,
#     ) -> None:
#         super(ColumnParallelLinear, self).__init__()
#
#         self.in_features = in_features
#         self.out_features = out_features
#         assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
#         self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
#         self.gather_output = gather_output
#
#         # Allocate space for the weight and bias
#         # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
#         self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
#         if bias:
#             self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
#             # Always initialize bias to zero.
#             with torch.no_grad():
#                 self.bias.zero_()
#         else:
#             self.register_parameter("bias", None)
#
#         # Initialize weight.
#         _initialize_affine_weight(
#             self.weight,
#             self.out_features,
#             self.in_features,
#             self.output_size_per_partition,
#             partition_dim = 0,
#             init_method = init_method,
#         )
#
#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
#         input_parallel = copy_to_model_parallel_region(input_)
#         output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
#         if self.gather_output:
#             output = gather_from_model_parallel_region(output)
#         return output

# class RowParallelLinear(torch.nn.Module):
#     """Linear layer with row parallelism.
#     Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
#                -   -
#               | W_1 |
#               | .   |
#           W = | .   |    X = [X_1, ..., X_p]
#               | .   |
#               | W_p |
#                -   -
#     We assume that X is already parallelized. This is the case after ColumnParallelLinear.
#     This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method.
#     Arguments:
#         in_features: first dimension of matrix W.
#         out_features: second dimension of matrix W.
#         bias: If true, add bias
#         init_method: method to initialize weights.
#     """
#
#     def __init__(
#         self,
#         in_features: int,
#         out_features: int,
#         bias: bool = True,
#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
#     ):
#         super(RowParallelLinear, self).__init__()
#
#         # Keep input parameters
#         self.in_features = in_features
#         self.out_features = out_features
#         self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size
#
#         self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
#         if bias:
#             self.bias = Parameter(torch.Tensor(self.out_features))
#             # Always initialize bias to zero.
#             with torch.no_grad():
#                 self.bias.zero_()
#         else:
#             self.register_parameter("bias", None)
#
#         # Initialize weight.
#         _initialize_affine_weight(
#             self.weight,
#             self.out_features,
#             self.in_features,
#             self.input_size_per_partition,
#             partition_dim = 1,
#             init_method = init_method,
#         )
#
#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
#         output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b
#         # All-reduce across all the partitions.
#         output_ = reduce_from_model_parallel_region(output_parallel)
#         if self.bias is not None:
#             output = output_ + self.bias
#         else:
#             output = output_
#         return output

# class VocabParallelEmbedding(torch.nn.Module):
#     """Embedding parallelized in the vocabulary dimension.
#     This is mainly adapted from torch.nn.Embedding and all the default values are kept.
#     Arguments:
#         num_embeddings: vocabulary size.
#         embedding_dim: size of hidden state.
#         init_method: method to initialize weights.
#     """
#
#     def __init__(
#         self,
#         num_embeddings: int,
#         embedding_dim: int,
#         padding_idx: Optional[int] = None,
#         max_norm: Optional[float] = None,
#         norm_type: float = 2.0,
#         scale_grad_by_freq: bool = False,
#         sparse: bool = False,
#         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
#     ) -> None:
#         super(VocabParallelEmbedding, self).__init__()
#         # Keep the input dimensions.
#         self.num_embeddings = num_embeddings
#         self.embedding_dim = embedding_dim
#         self.padding_idx = padding_idx
#         self.max_norm = max_norm
#         self.norm_type = norm_type
#         self.scale_grad_by_freq = scale_grad_by_freq
#         self.sparse = sparse
#         self._weight = None
#         # Divide the weight matrix along the vocabulary dimension.
#         self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
#             self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
#         )
#         self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
#
#         # Allocate weights.
#         self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
#         # And initialize.
#         _initialize_affine_weight(
#             self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
#         )
#
#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
#         """
#         Performs an embedding lookup for input tokens in the parallelized embedding layer
#         1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
#         2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
#         3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
#         """
#         # Build the mask for out-of-vocabulary tokens.
#         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
#         # Mask the input.
#         masked_input = input_.clone() - self.vocab_start_index
#         masked_input[input_mask] = 0
#         # Get the embeddings for the valid tokens.
#         output_parallel = F.embedding(
#             masked_input,
#             self.weight,
#             self.padding_idx,
#             self.max_norm,
#             self.norm_type,
#             self.scale_grad_by_freq,
#             self.sparse,
#         )
#         # Embedding of out-of-vocabulary tokens is set to 0.
#         output_parallel[input_mask, :] = 0.0
#         # Reduce across all the model parallel GPUs to get the final output.
#         output = reduce_from_model_parallel_region(output_parallel)
#         return output


def _initialize_affine_weight(
    weight: torch.Tensor,
    out_features: int,
    in_features: int,
    per_partition_size: int,
    partition_dim: int,
    init_method: Callable[[torch.Tensor], torch.Tensor],
    stride: int = 1,
    return_master_weight: bool = False,
) -> Optional[torch.Tensor]:
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter
    the relevant chunk."""

    # If we only use 1 process for model parallelism, bypass scatter.
    world_size = pgm.process_group_manager.world_size
    if world_size == 1:
        init_method(weight)
        if return_master_weight:
            return weight
        return None

    # Initialize master weight
    master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
    # init_method(master_weight)

    k = 1.0 / in_features
    bound = torch.sqrt(torch.tensor(k, dtype=master_weight.dtype))

    # Use PyTorch's built-in uniform initialization
    init.uniform_(master_weight, -bound.item(), bound.item())

    # Split and copy
    per_partition_per_stride_size = divide_and_check_no_remainder(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
    rank = pgm.process_group_manager.tp_rank
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None

class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        in_features: first dimension of matrix A.
        out_features: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        gather_output: bool = True,
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
        stride: int = 1,
        keep_master_weight_for_test: bool = False,
    ) -> None:
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = pgm.process_group_manager.tp_world_size
        self.output_size_per_partition = divide_and_check_no_remainder(out_features, world_size)

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

        # Initialize weight.
        self.master_weight = _initialize_affine_weight(
            self.weight,
            self.out_features,
            self.in_features,
            self.output_size_per_partition,
            0,
            init_method,
            stride=stride,
            return_master_weight=keep_master_weight_for_test,
        )

    def get_master_weight(self) -> torch.Tensor:
        return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1)

    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
        # Backprop: all-reduce.
        input_parallel = copy_to_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        return output


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        in_features: first dimension of matrix A.
        out_features: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        input_is_parallel: bool = True,  # Normally, input is parallelized, especially in Attention projection/MLP. There is a column parallel before it.
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
        stride: int = 1,
        keep_master_weight_for_test: bool = False,
    ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension.
        world_size = pgm.process_group_manager.tp_world_size
        self.input_size_per_partition = divide_and_check_no_remainder(in_features, world_size)

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
        if bias:
            self.bias = Parameter(torch.Tensor(self.out_features))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

        # Initialize weight.
        self.master_weight = _initialize_affine_weight(
            self.weight,
            self.out_features,
            self.in_features,
            self.input_size_per_partition,
            1,
            init_method,
            stride=stride,
            return_master_weight=keep_master_weight_for_test,
        )

    def get_master_weight(self) -> torch.Tensor:
        return gather_from_model_parallel_region(self.weight.data)

    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
        # Set up backprop all-reduce.
        input_parallel = input_

        # Matrix multiply
        output_parallel = F.linear(input_parallel, self.weight)
        # All-reduce across all the partitions.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if self.bias is not None:
            output = output_ + self.bias
        else:
            output = output_
        return output

class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
    ) -> None:
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        self._weight = None
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
        )
        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

        # Allocate weights.
        self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
        # And initialize.
        _initialize_affine_weight(
            self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
        )

    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
        # Build the mask.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Mask the input.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
        # Get the embeddings.
        output_parallel = F.embedding(
            masked_input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        # Mask the output embedding.
        # Embedding of tokens that are not in the vocabulary is set to 0; do an all-reduce at the end.
        output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_model_parallel_region(output_parallel)
        return output
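The two linear layers above rest on a simple pair of identities: splitting A by columns and all-gathering the partial outputs reproduces XA, while splitting A by rows (and X along its last dimension) and summing (all-reducing) the partial products also reproduces XA. A single-process sketch of those identities using plain matmuls (shapes and the shard count are illustrative; no process groups are involved):

import torch

torch.manual_seed(0)
tp = 2
X = torch.randn(3, 8)   # (batch, in_features)
A = torch.randn(8, 6)   # full weight, Y = X @ A

# Column parallel: A = [A_1, A_2] along dim 1; each rank holds one shard.
col_shards = A.chunk(tp, dim=1)
Y_cols = [X @ A_i for A_i in col_shards]                             # each rank's Y_i
assert torch.allclose(torch.cat(Y_cols, dim=1), X @ A, atol=1e-5)    # all-gather == full output

# Row parallel: A stacked along dim 0, X split along its last dimension.
row_shards = A.chunk(tp, dim=0)
X_shards = X.chunk(tp, dim=1)
Y_partial = [X_i @ A_i for X_i, A_i in zip(X_shards, row_shards)]
assert torch.allclose(sum(Y_partial), X @ A, atol=1e-5)              # all-reduce (sum) == full output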
src/parallel/tensor_parallel/mappings.py (new file) | 96
@@ -0,0 +1,96 @@
"""
Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from src.parallel.tensor_parallel.utils import split_tensor_along_last_dim
import torch.distributed as dist
import torch
import src.distributed.process_group_manager as pgm

def _reduce(input_):
    """All-reduce the input tensor across the model parallel (tensor parallel) group."""
    # Bypass the function if we are using only 1 GPU.
    if pgm.process_group_manager.tp_size == 1:
        return input_

    # All-reduce across the tensor parallel group
    dist.all_reduce(input_, group=pgm.process_group_manager.tp_group)

    return input_

class _CopyToModelParallelRegion(torch.autograd.Function):
    """copy (identity) in forward pass, all-reduce in backward pass"""
    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)

class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """all-reduce in forward pass, copy (identity) in backward pass"""
    @staticmethod
    def forward(ctx, input_):  # type: ignore
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):  # type: ignore
        return grad_output

# This is the `f` function in the paper: https://arxiv.org/abs/1909.08053
def copy_to_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _CopyToModelParallelRegion.apply(input_)

# This is the `g` function in the paper, which is the conjugate of `f`
def reduce_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _ReduceFromModelParallelRegion.apply(input_)

def _split(input_: torch.Tensor) -> torch.Tensor:
    """Split the tensor along its last dimension and keep the corresponding slice."""
    tp_rank = pgm.process_group_manager.tp_rank
    tp_world_size = pgm.process_group_manager.tp_size

    # Bypass the function if we are using only 1 GPU
    if tp_world_size == 1:
        return input_

    # Split along last dimension and keep the corresponding slice
    input_list = split_tensor_along_last_dim(input_, tp_world_size)
    output = input_list[tp_rank].contiguous()

    return output

def _gather(input_: torch.Tensor) -> torch.Tensor:
    """Gather tensors and concatenate along the last dimension."""
    tp_rank = pgm.process_group_manager.tp_rank
    tp_world_size = pgm.process_group_manager.tp_size

    # Bypass the function if we are using only 1 GPU.
    if tp_world_size == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1

    tensor_list = [torch.empty_like(input_) for _ in range(tp_world_size)]
    tensor_list[tp_rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=pgm.process_group_manager.tp_group)

    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output

class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather in the forward pass, split in the backward pass."""

    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)

def gather_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _GatherFromModelParallelRegion.apply(input_)
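mappings.py boils down to two collectives wrapped in autograd Functions: an all-reduce for the f/g conjugate pair and an all-gather (with a last-dimension split in backward) for gathering shards. A throwaway two-process CPU sketch of those collectives (an assumption for illustration: gloo backend, default process group, made-up port; the real code runs them on pgm.process_group_manager.tp_group on GPUs):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # What reduce_from_model_parallel_region does in its forward pass:
    partial = torch.tensor([float(rank + 1)])
    dist.all_reduce(partial)                        # sum of per-rank partial results

    # What gather_from_model_parallel_region does in its forward pass:
    shard = torch.tensor([float(rank)])
    shards = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(shards, shard)
    full = torch.cat(shards, dim=-1)                # concatenate along the last dimension

    if rank == 0:
        print("all_reduce:", partial, "all_gather:", full)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)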
src/parallel/tensor_parallel/tensor_parallel.py (new file) | 51
@@ -0,0 +1,51 @@
from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
import torch.nn.init as init
import torch.nn as nn

class TensorParallel(nn.Module):
    def __init__(self, model, init_method = init.xavier_normal_):
        module_linear_name_stype_mapping_list = [
            ("attention", "qkv_proj", "column"),
            ("attention", "out_proj", "row"),
            ("mlp", "gate_up_proj", "column"),
            ("mlp", "down_proj", "row"),
        ]

        self.init_method = init_method

        for layer in model.decoder_layers:
            for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
                self.replace_module(getattr(layer, module_name), linear_proj_name, style)
        self.replace_module(model, "embedding", "vocab")
        self.replace_module(model, "final_proj", "column", args={"gather_output": True})

        # for name, param in model.named_parameters():
        #     print(name, param.shape, param.requires_grad)

    def replace_module(self, module, linear_proj_name, style, args = {}):
        assert style in ["column", "row", 'vocab']
        linear_layer = getattr(module, linear_proj_name)
        if style == "column":
            new_linear_layer = ColumnParallelLinear(
                in_features=linear_layer.in_features,
                out_features=linear_layer.out_features,
                bias=linear_layer.bias is not None,
                init_method=self.init_method,
                gather_output=args.get("gather_output", False)
            )
        elif style == "row":
            new_linear_layer = RowParallelLinear(
                in_features=linear_layer.in_features,
                out_features=linear_layer.out_features,
                bias=linear_layer.bias is not None,
                init_method=self.init_method
            )
        else:
            new_linear_layer = VocabParallelEmbedding(
                num_embeddings=linear_layer.num_embeddings,
                embedding_dim=linear_layer.embedding_dim,
                init_method=self.init_method
            )
        setattr(module, linear_proj_name, new_linear_layer)
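TensorParallel works by in-place surgery rather than by wrapping the model: it looks up each projection by attribute name, builds a parallel replacement with the same in/out dimensions, and setattrs it back onto the parent module. A toy, repo-independent sketch of that swap pattern (ToyBlock and the factory lambda are invented for illustration):

import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 16, bias=False)

def swap_linear(module: nn.Module, name: str, factory):
    # Fetch the existing layer, build a same-shaped replacement, and swap it in place.
    old = getattr(module, name)
    new = factory(old.in_features, old.out_features)
    setattr(module, name, new)

block = ToyBlock()
swap_linear(block, "proj", lambda i, o: nn.Linear(i, o, bias=False))
print(type(block.proj))  # still an nn.Linear here; the commit swaps in the parallel layers instead

In the commit itself the replacement factories are ColumnParallelLinear, RowParallelLinear and VocabParallelEmbedding, selected per projection by the ("attention"/"mlp", name, style) mapping list above.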
src/parallel/tensor_parallel/utils.py (new file) | 47
@@ -0,0 +1,47 @@
"""
Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from typing import Tuple
import torch

def divide_and_check_no_remainder(numerator: int, denominator: int) -> int:
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
    return numerator // denominator

def split_tensor_along_last_dim(
    tensor: torch.Tensor, num_partitions: int
) -> Tuple[torch.Tensor, ...]:
    """Split a tensor along its last dimension.
    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide_and_check_no_remainder(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    return tensor_list

class VocabUtility:
    """Split the vocabulary into `world_size` chunks and return the
    first and last index of the vocabulary belonging to the `rank`
    partition. Note that indices are in [first, last)."""

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int, rank: int
    ) -> Tuple[int, int]:
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Tuple[int, int]:
        per_partition_vocab_size = divide_and_check_no_remainder(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)
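A short usage sketch of the helpers above (the vocabulary size and TP degree are illustrative numbers; the import path is the one this commit adds):

from src.parallel.tensor_parallel.utils import VocabUtility, divide_and_check_no_remainder

vocab_size, tp_world_size = 32000, 2
for rank in range(tp_world_size):
    first, last = VocabUtility.vocab_range_from_global_vocab_size(vocab_size, rank, tp_world_size)
    print(rank, (first, last))   # rank 0 -> (0, 16000), rank 1 -> (16000, 32000)

print(divide_and_check_no_remainder(vocab_size, tp_world_size))  # 16000 rows of the embedding per rank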
train.py | 52
@@ -1,8 +1,9 @@
"""Training script for LLaMA model.
torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py
torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --use_wandb
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --tp_size 2
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 2
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 1 --dp_size 2
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=1 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --tp_size 2
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
"""
@@ -19,25 +20,26 @@ from datasets import load_dataset
import argparse
from datasets import Features, Sequence, Value
import numpy as np
from src.parallel.tensor_parallel.tensor_parallel import TensorParallel
import src.distributed.process_group_manager as pgm
from utils import set_all_seed, print
from src.distributed.process_group_manager import setup_process_group_manager
from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from dataclasses import dataclass
from src.parallel.data_parallel.data_parallel_bucket import DataParallel
from src.parallel.context_parallel import ContextParallel
# from src.parallel.context_parallel import ContextParallel
from model import LLaMA
import wandb
import multiprocessing

class MicroBatchDataLoader(DataLoader):
    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None, num_workers=0):
    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, grad_acc = 1, split="train", num_samples=None, num_workers=0):
        self.global_batch_size = global_batch_size
        self.micro_batch_size = micro_batch_size
        self.seq_length = seq_length
        self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size # each DP rank gets a local batch
        self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
        self.grad_acc = grad_acc

        self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
@@ -137,17 +139,19 @@ def train_step(model, data_loader, device):
    input_ids = batch["input_ids"].to(device)
    target_ids = batch["target_ids"].to(device)

    outputs = model(input_ids=input_ids)
    for i in range(data_loader.grad_acc):
        outputs = model(input_ids=input_ids)

    # compute the loss
    batch_size, seq_len = input_ids.shape
    target_ids = target_ids.reshape(-1)
    outputs = outputs.view(seq_len*batch_size, -1)
    loss = F.cross_entropy(outputs, target_ids, reduction='mean')

    loss.backward()
        # compute the loss
        batch_size, seq_len = input_ids.shape
        target_ids = target_ids.reshape(-1)
        outputs = outputs.view(seq_len*batch_size, -1)
        loss = F.cross_entropy(outputs, target_ids, reduction='mean')

        loss.backward()

    acc_loss += loss.item()
        acc_loss += loss.item()
    acc_loss /= data_loader.grad_acc

    return acc_loss
@@ -175,6 +179,7 @@ if __name__ == "__main__":
    # SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
    ## hyperparameters
    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 16, 4, 3e-4, 100000, int(10e8), 42
    grad_acc = 16

    assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
@@ -197,10 +202,12 @@ if __name__ == "__main__":
    dataset_name = "roneneldan/TinyStories"
    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
    config = AutoConfig.from_pretrained(model_name)
    config.num_attention_heads = 16
    config.num_key_value_heads = 4

    model = LLaMA(
        config=config
    ).to(device)
    )

    if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
        wandb.init(
@@ -220,8 +227,11 @@ if __name__ == "__main__":
            },
        )

    if pgm.process_group_manager.cp_size > 1:
        model = ContextParallel(model, config)
    if pgm.process_group_manager.tp_world_size > 1:
        TensorParallel(model)

    # if pgm.process_group_manager.cp_size > 1:
    #     model = ContextParallel(model, config)

    if pgm.process_group_manager.pp_world_size > 1:
        model = PipelineParallel(model, config)
@@ -229,17 +239,15 @@ if __name__ == "__main__":
    if pgm.process_group_manager.dp_world_size > 1:
        model = DataParallel(model, pgm.process_group_manager.dp_group)

    # if pgm.process_group_manager.tp_world_size > 1:
    #     model = TensorParallel(model, config)

    model.to(device)
    model.train()

    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, grad_acc = grad_acc, num_samples=NUM_SAMPLES)
    tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    trained_tokens, step = 0, 0
    tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN
    tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN * grad_acc

    dist.barrier()
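One consequence of threading grad_acc through the data loader is the updated tokens-per-step accounting at the bottom of this hunk. A worked version of that arithmetic using the hyperparameters set in this file:

# Worked example of the updated tokens-per-step accounting (values from this train.py).
GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, grad_acc = 16, 4, 1024, 16

num_global_micro_batches = GLOBAL_BATCH_SIZE // MICRO_BATCH_SIZE          # 4
tokens_per_step = num_global_micro_batches * MICRO_BATCH_SIZE * SEQ_LEN * grad_acc
print(tokens_per_step)  # 262144 tokens consumed per optimizer step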