tensor parallel, will clean later

This commit is contained in:
zzhhjjj 2024-10-18 05:13:44 +00:00
parent 54ad77e055
commit 7377238741
12 changed files with 764 additions and 32 deletions

5
.gitignore vendored
View File

@ -1,3 +1,6 @@
__pycache__
*.pth
.vscode/
picotron.egg-info
*.ipynb
wandb

View File

@ -5,6 +5,8 @@ import torch.nn.functional as F
import torch.nn.init as init
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.layers.rotary import apply_rotary_emb
import src.distributed.process_group_manager as pgm
from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16 if os.getenv('DATA_TYPE', 'bfloat16') == 'bfloat16' else torch.float32
@ -56,13 +58,13 @@ class CausalSelfAttention(nn.Module):
self.num_heads = config.num_attention_heads
self.num_key_values = config.num_key_value_heads
self.head_dim = self.hidden_size//self.num_heads
# model_parallel_size = get_model_parallel_world_size()
model_parallel_size = 1
model_parallel_size = pgm.process_group_manager.tp_world_size
self.num_local_heads = config.num_attention_heads // model_parallel_size # TP parallelism
self.num_local_kv_heads = config.num_key_value_heads // model_parallel_size # TP parallelism
self.is_merged_qkv_weight = os.getenv('MERGED_QKV_WEIGHT', '1')
if self.is_merged_qkv_weight == '1':
self.qkv_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim + 2*self.num_key_values*self.head_dim, bias=False)
# self.qkv_proj = ColumnParallelLinear(config.hidden_size, self.num_heads*self.head_dim + 2*self.num_key_values*self.head_dim, bias=False, gather_output=False, init_method=init_method)
else:
self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False)
@ -134,11 +136,12 @@ class LLaMAMLP(nn.Module):
self.merged_gate_up = os.getenv('MERGED_GATE_UP_WEIGHT', '1') == '1'
if self.merged_gate_up:
self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size*2, bias=False)
# self.gate_up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size*2, bias=False, gather_output=False, init_method=init_method)
else:
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
# self.up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
# self.gate_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
# self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
# self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
self.gate_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
# self.down_proj = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=False, input_is_parallel=True, init_method=init_method)
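For orientation (not part of the diff): with the head counts set later in this commit's train.py (16 attention heads, 4 KV heads) and a tensor-parallel size of 2, the per-rank shard sizes work out as in the sketch below. The hidden and intermediate sizes are assumed example values, not taken from the repo.
# Assumed example values; only the head counts (16/4) and tp_size=2 appear in this commit.
hidden_size, intermediate_size = 960, 2560
num_heads, num_kv_heads, tp_world_size = 16, 4, 2
head_dim = hidden_size // num_heads                      # 60
num_local_heads = num_heads // tp_world_size             # 8 query heads per TP rank
num_local_kv_heads = num_kv_heads // tp_world_size       # 2 KV heads per TP rank
# ColumnParallelLinear shards its output features, so each rank's gate/up projection
# produces intermediate_size // tp_world_size columns.
print(head_dim, num_local_heads, num_local_kv_heads, intermediate_size // tp_world_size)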

6
setup.py Normal file
View File

@ -0,0 +1,6 @@
from setuptools import setup, find_packages
setup(
name="picotron", # Name of the package
version='0.1.0',
packages=find_packages(), # Automatically find packages in the current directory
)

View File

@ -8,7 +8,7 @@ from src.distributed.distributed_primtives import ContextComms
from model import Attention
import src.distributed.process_group_manager as pgm
from parallel.base_parallel import BaseParallel
from src.parallel.base_parallel import BaseParallel
class ContextParallel(BaseParallel):
def __init__(self, model, config):

View File

@ -3,7 +3,7 @@ from src.distributed.distributed_primtives import pipeline_communicate, bidirect
import torch, torch.nn as nn, torch.nn.functional as F
import torch.distributed as dist
from parallel.base_parallel import BaseParallel
from src.parallel.base_parallel import BaseParallel
class PipelineParallel(BaseParallel):
def __init__(self, model, config):

View File

@ -1 +0,0 @@
#TODO

View File

View File

@ -0,0 +1,519 @@
"""
Inspired by FairScale/Megatron's tensor parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from src.parallel.tensor_parallel.utils import VocabUtility, divide_and_check_no_remainder
import torch
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from typing import Callable, Optional
import src.distributed.process_group_manager as pgm
from src.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
# def _initialize_affine_weight(
# weight: torch.Tensor,
# out_features: int,
# in_features: int,
# per_partition_size: int,
# partition_dim: int,
# init_method: Callable[[torch.Tensor], torch.Tensor]
# ) -> Optional[torch.Tensor]:
# """
# Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight
# Args:
# weight: The weight tensor that will be initialized for the current partition.
# out_features: second dimension of weight matrix W.
# in_features: first dimension of weight matrix W.
# per_partition_size: The size of the weight partition assigned to each process.
# partition_dim: The dimension along which the weight matrix is split for parallelism.
# init_method: The method used to initialize the weight values.
# """
# # If we only use 1 process for model parallelism, we can simply initialize the weight
# if pgm.process_group_manager.tp_world_size == 1:
# init_method(weight)
# return None
# # Initialize master weight
# master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
# init_method(master_weight)
# # Split the model into size of per_partition_size and take the corresponding partition
# weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
# weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous()
# return None
# class ColumnParallelLinear(torch.nn.Module):
# """Column Parallel Linear layer
# Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
# This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension.
# Arguments:
# in_features: first dimension of weight matrix W.
# out_features: second dimension of weight matrix W.
# bias: If true, add bias
# init_method: method to initialize weights
# gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
# """
# def __init__(
# self,
# in_features: int,
# out_features: int,
# bias: bool = False,
# init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
# gather_output: bool = False,
# ) -> None:
# super(ColumnParallelLinear, self).__init__()
# self.in_features = in_features
# self.out_features = out_features
# assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
# self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
# self.gather_output = gather_output
# # Allocate space for the weight and bias
# # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
# self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
# if bias:
# self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
# # Always initialize bias to zero.
# with torch.no_grad():
# self.bias.zero_()
# else:
# self.register_parameter("bias", None)
# # Initialize weight.
# _initialize_affine_weight(
# self.weight,
# self.out_features,
# self.in_features,
# self.output_size_per_partition,
# partition_dim = 0,
# init_method = init_method,
# )
# def forward(self, input_: torch.Tensor) -> torch.Tensor:
# input_parallel = copy_to_model_parallel_region(input_)
# output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
# if self.gather_output:
# output = gather_from_model_parallel_region(output)
# return output
# class RowParallelLinear(torch.nn.Module):
# """Linear layer with row parallelism.
# Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
# - -
# | W_1 |
# | . |
# W = | . | X = [X_1, ..., X_p]
# | . |
# | W_p |
# - -
# We assume that X is already parallelized. This is the case after ColumnParallelLinear.
# This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method.
# Arguments:
# in_features: first dimension of matrix W.
# out_features: second dimension of matrix W.
# bias: If true, add bias
# init_method: method to initialize weights.
# """
# def __init__(
# self,
# in_features: int,
# out_features: int,
# bias: bool = True,
# init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
# ):
# super(RowParallelLinear, self).__init__()
# # Keep input parameters
# self.in_features = in_features
# self.out_features = out_features
# self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size
# self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
# if bias:
# self.bias = Parameter(torch.Tensor(self.out_features))
# # Always initialize bias to zero.
# with torch.no_grad():
# self.bias.zero_()
# else:
# self.register_parameter("bias", None)
# # Initialize weight.
# _initialize_affine_weight(
# self.weight,
# self.out_features,
# self.in_features,
# self.input_size_per_partition,
# partition_dim = 1,
# init_method = init_method,
# )
# def forward(self, input_: torch.Tensor) -> torch.Tensor:
# output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b
# # All-reduce across all the partitions.
# output_ = reduce_from_model_parallel_region(output_parallel)
# if self.bias is not None:
# output = output_ + self.bias
# else:
# output = output_
# return output
# class VocabParallelEmbedding(torch.nn.Module):
# """Embedding parallelized in the vocabulary dimension.
# This is mainly adapted from torch.nn.Embedding and all the default values are kept.
# Arguments:
# num_embeddings: vocabulary size.
# embedding_dim: size of hidden state.
# init_method: method to initialize weights.
# """
# def __init__(
# self,
# num_embeddings: int,
# embedding_dim: int,
# padding_idx: Optional[int] = None,
# max_norm: Optional[float] = None,
# norm_type: float = 2.0,
# scale_grad_by_freq: bool = False,
# sparse: bool = False,
# init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
# ) -> None:
# super(VocabParallelEmbedding, self).__init__()
# # Keep the input dimensions.
# self.num_embeddings = num_embeddings
# self.embedding_dim = embedding_dim
# self.padding_idx = padding_idx
# self.max_norm = max_norm
# self.norm_type = norm_type
# self.scale_grad_by_freq = scale_grad_by_freq
# self.sparse = sparse
# self._weight = None
# # Divide the weight matrix along the vocabulary dimension.
# self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
# self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
# )
# self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
# # Allocate weights.
# self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
# # And initialize.
# _initialize_affine_weight(
# self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
# )
# def forward(self, input_: torch.Tensor) -> torch.Tensor:
# """
# Performs an embedding lookup for input tokens in the parallelized embedding layer
# 1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
# 2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
# 3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
# """
# # Build the mask for out-of-vocabulary tokens.
# input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# # Mask the input.
# masked_input = input_.clone() - self.vocab_start_index
# masked_input[input_mask] = 0
# # Get the embeddings for the valid tokens.
# output_parallel = F.embedding(
# masked_input,
# self.weight,
# self.padding_idx,
# self.max_norm,
# self.norm_type,
# self.scale_grad_by_freq,
# self.sparse,
# )
# # Embedding of out-of-vocabulary tokens is set to 0.
# output_parallel[input_mask, :] = 0.0
# # Reduce across all the model parallel GPUs to get the final output.
# output = reduce_from_model_parallel_region(output_parallel)
# return output
def _initialize_affine_weight(
weight: torch.Tensor,
out_features: int,
in_features: int,
per_partition_size: int,
partition_dim: int,
init_method: Callable[[torch.Tensor], torch.Tensor],
stride: int = 1,
return_master_weight: bool = False,
) -> Optional[torch.Tensor]:
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk."""
# If we only use 1 process for model parallelism, bypass scatter.
world_size = pgm.process_group_manager.tp_world_size
if world_size == 1:
init_method(weight)
if return_master_weight:
return weight
return None
# Initialize master weight
master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
# init_method(master_weight)
k = 1.0 / in_features
bound = torch.sqrt(torch.tensor(k, dtype=master_weight.dtype))
# Use PyTorch's built-in uniform initialization
init.uniform_(master_weight, -bound.item(), bound.item())
# Split and copy
per_partition_per_stride_size = divide_and_check_no_remainder(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
rank = pgm.process_group_manager.tp_rank
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
if return_master_weight:
return master_weight
return None
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
in_features: first dimension of matrix A.
out_features: second dimension of matrix A.
bias: If true, add bias
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
gather_output: bool = True,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
) -> None:
super(ColumnParallelLinear, self).__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = pgm.process_group_manager.tp_world_size
self.output_size_per_partition = divide_and_check_no_remainder(out_features, world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))
if bias:
self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.output_size_per_partition,
0,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
# Backprop: all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, self.bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
return output
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
in_features: first dimension of matrix A.
out_features: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
input_is_parallel: bool = True, # The input is normally already parallelized (e.g., the output of a ColumnParallelLinear in attention/MLP), so it is not split again.
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
):
super(RowParallelLinear, self).__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = pgm.process_group_manager.tp_world_size
self.input_size_per_partition = divide_and_check_no_remainder(in_features, world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
if bias:
self.bias = Parameter(torch.Tensor(self.out_features))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.input_size_per_partition,
1,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
# Set up backprop all-reduce.
input_parallel = input_
# Matrix multiply
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel)
if self.bias is not None:
output = output_ + self.bias
else:
output = output_
return output
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
) -> None:
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
self._weight = None
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
# Allocate weights.
self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
# And initialize.
_initialize_affine_weight(
self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
# Get the embeddings.
output_parallel = F.embedding(
masked_input,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
# Mask the output embedding.
# Embeddings of out-of-vocabulary tokens are set to 0; the all-reduce below sums the partial results across TP ranks.
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_model_parallel_region(output_parallel)
return output
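Not part of the diff, but a self-contained sketch (single process, plain matrices, no process groups) of the algebra these layers rely on: splitting the first weight by columns and the second by rows reproduces the dense result after a single summation, which is exactly what the all-reduce in RowParallelLinear performs at the end of an MLP or attention block.
import torch

torch.manual_seed(0)
tp = 2                                    # simulated tensor-parallel world size
X = torch.randn(4, 8)                     # activations [batch, in_features]
A = torch.randn(8, 6)                     # column-parallel weight, split along dim 1
B = torch.randn(6, 8)                     # row-parallel weight, split along dim 0

Z_ref = (X @ A) @ B                       # dense reference

A_shards = torch.chunk(A, tp, dim=1)      # each rank holds A_i and computes Y_i = X @ A_i
B_shards = torch.chunk(B, tp, dim=0)      # each rank holds the matching rows B_i
partials = [(X @ A_i) @ B_i for A_i, B_i in zip(A_shards, B_shards)]

Z_tp = sum(partials)                      # the all-reduce sums these partial results
print(torch.allclose(Z_ref, Z_tp, atol=1e-5))   # True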

View File

@ -0,0 +1,96 @@
"""
Inspired by FairScale/Megatron's tensor parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from src.parallel.tensor_parallel.utils import split_tensor_along_last_dim
import torch.distributed as dist
import torch
import src.distributed.process_group_manager as pgm
def _reduce(input_):
"""All-reduce the input tensor across model parallel(Tensor Parallel) group."""
# Bypass the function if we are using only 1 GPU.
if pgm.process_group_manager.tp_size == 1:
return input_
# All-reduce across the tensor parallel group
dist.all_reduce(input_, group=pgm.process_group_manager.tp_group)
return input_
class _CopyToModelParallelRegion(torch.autograd.Function):
"""copy(identity) in forward pass, all reduce in backward pass"""
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""all reduce in forward pass, copy(identity) in backward pass"""
@staticmethod
def forward(ctx, input_): # type: ignore
return _reduce(input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
return grad_output
# This is the `f` function in the paper: https://arxiv.org/abs/1909.08053
def copy_to_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _CopyToModelParallelRegion.apply(input_)
# This is the `g` function in the paper, which is the conjugate of `f`
def reduce_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ReduceFromModelParallelRegion.apply(input_)
def _split(input_: torch.Tensor) -> torch.Tensor:
"""Split the tensor along its last dimension and keep the corresponding slice."""
tp_rank = pgm.process_group_manager.tp_rank
tp_world_size = pgm.process_group_manager.tp_size
# Bypass the function if we are using only 1 GPU
if tp_world_size == 1:
return input_
# Split along last dimension and keep the corresponding slice
input_list = split_tensor_along_last_dim(input_, tp_world_size)
output = input_list[tp_rank].contiguous()
return output
def _gather(input_: torch.Tensor) -> torch.Tensor:
"""Gather tensors and concatinate along the last dimension."""
tp_rank = pgm.process_group_manager.tp_rank
tp_world_size = pgm.process_group_manager.tp_size
# Bypass the function if we are using only 1 GPU.
if tp_world_size == 1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
tensor_list = [torch.empty_like(input_) for _ in range(tp_world_size)]
tensor_list[tp_rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=pgm.process_group_manager.tp_group)
output = torch.cat(tensor_list, dim=last_dim).contiguous()
return output
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gathher in the forward pass, split in the backward pass."""
@staticmethod
def forward(ctx, input_):
return _gather(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
def gather_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _GatherFromModelParallelRegion.apply(input_)
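As a side note (not from the repo), the reason _GatherFromModelParallelRegion can use a plain split as its backward can be checked on a single process with autograd: the gradient of a concatenation along the last dimension is just the matching slice of the upstream gradient.
import torch

x0 = torch.randn(2, 3, requires_grad=True)    # shard owned by "rank 0"
x1 = torch.randn(2, 3, requires_grad=True)    # shard owned by "rank 1"

gathered = torch.cat([x0, x1], dim=-1)        # forward of _gather
loss = (gathered ** 2).sum()
loss.backward()

full_grad = 2 * gathered.detach()
print(torch.allclose(x0.grad, full_grad[:, :3]))   # True: each shard gets only its slice
print(torch.allclose(x1.grad, full_grad[:, 3:]))   # True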

View File

@ -0,0 +1,51 @@
from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
import torch.nn.init as init
import torch.nn as nn
class TensorParallel(nn.Module):
def __init__(self, model, init_method = init.xavier_normal_):
module_linear_name_style_mapping_list = [
("attention", "qkv_proj", "column"),
("attention", "out_proj", "row"),
("mlp", "gate_up_proj", "column"),
("mlp", "down_proj", "row"),
]
self.init_method = init_method
for layer in model.decoder_layers:
for module_name, linear_proj_name, style in module_linear_name_style_mapping_list:
self.replace_module(getattr(layer, module_name), linear_proj_name, style)
self.replace_module(model, "embedding", "vocab")
self.replace_module(model, "final_proj", "column", args={"gather_output": True})
# for name, param in model.named_parameters():
# print(name, param.shape, param.requires_grad)
def replace_module(self, module, linear_proj_name, style, args={}):
assert style in ["column", "row", 'vocab']
linear_layer = getattr(module, linear_proj_name)
if style == "column":
new_linear_layer = ColumnParallelLinear(
in_features=linear_layer.in_features,
out_features=linear_layer.out_features,
bias=linear_layer.bias is not None,
init_method=self.init_method,
gather_output=args.get("gather_output", False)
)
elif style == "row":
new_linear_layer = RowParallelLinear(
in_features=linear_layer.in_features,
out_features=linear_layer.out_features,
bias=linear_layer.bias is not None,
init_method=self.init_method
)
else:
new_linear_layer = VocabParallelEmbedding(
num_embeddings=linear_layer.num_embeddings,
embedding_dim=linear_layer.embedding_dim,
init_method=self.init_method
)
setattr(module, linear_proj_name, new_linear_layer)
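For clarity (illustration only, no process groups involved): replace_module relies on nothing more exotic than getattr/setattr, which re-registers the new submodule and its parameters on the parent nn.Module. A toy version of the same swap, with an ordinary nn.Linear standing in for the sharded layer:
import torch.nn as nn

class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate_up_proj = nn.Linear(16, 64, bias=False)

toy = ToyMLP()
old = getattr(toy, "gate_up_proj")
# Stand-in for a column shard: half of the output features stay on this "rank".
new = nn.Linear(old.in_features, old.out_features // 2, bias=False)
setattr(toy, "gate_up_proj", new)
print(toy.gate_up_proj.weight.shape)     # torch.Size([32, 16])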

View File

@ -0,0 +1,47 @@
"""
Inspired by FairScale/Megatron's tensor parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from typing import Tuple
import torch
def divide_and_check_no_remainder(numerator: int, denominator: int) -> int:
"""Ensure that numerator is divisible by the denominator and return
the division value."""
assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
return numerator // denominator
def split_tensor_along_last_dim(
tensor: torch.Tensor, num_partitions: int
) -> Tuple[torch.Tensor, ...]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor into.
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide_and_check_no_remainder(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
return tensor_list
class VocabUtility:
"""Split the vocabulary into `world_size` chunks amd return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [first, last)"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank: int
) -> Tuple[int, int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Tuple[int, int]:
per_partition_vocab_size = divide_and_check_no_remainder(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)
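A quick worked example of the [first, last) ranges these helpers produce, using an assumed vocabulary size of 32000 split across 4 tensor-parallel ranks (plain arithmetic, nothing repo-specific):
global_vocab_size, world_size = 32000, 4
per_partition = global_vocab_size // world_size    # 8000 embedding rows per rank
for rank in range(world_size):
    first = rank * per_partition
    last = first + per_partition
    print(rank, (first, last))     # e.g. rank 2 owns token ids [16000, 24000)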

View File

@ -1,8 +1,9 @@
"""Training script for LLaMA model.
torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py
torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --use_wandb
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --tp_size 2
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 2
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 1 --dp_size 2
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=1 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --tp_size 2
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
"""
@ -19,25 +20,26 @@ from datasets import load_dataset
import argparse
from datasets import Features, Sequence, Value
import numpy as np
from src.parallel.tensor_parallel.tensor_parallel import TensorParallel
import src.distributed.process_group_manager as pgm
from utils import set_all_seed, print
from src.distributed.process_group_manager import setup_process_group_manager
from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from dataclasses import dataclass
from src.parallel.data_parallel.data_parallel_bucket import DataParallel
from src.parallel.context_parallel import ContextParallel
# from src.parallel.context_parallel import ContextParallel
from model import LLaMA
import wandb
import multiprocessing
class MicroBatchDataLoader(DataLoader):
def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None, num_workers=0):
def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, grad_acc = 1, split="train", num_samples=None, num_workers=0):
self.global_batch_size = global_batch_size
self.micro_batch_size = micro_batch_size
self.seq_length = seq_length
self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size # each DP rank gets a local batch
self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
self.grad_acc = grad_acc
self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
@ -137,17 +139,19 @@ def train_step(model, data_loader, device):
input_ids = batch["input_ids"].to(device)
target_ids = batch["target_ids"].to(device)
outputs = model(input_ids=input_ids)
for i in range(data_loader.grad_acc):
outputs = model(input_ids=input_ids)
# compute the loss
batch_size, seq_len = input_ids.shape
target_ids = target_ids.reshape(-1)
outputs = outputs.view(seq_len*batch_size, -1)
loss = F.cross_entropy(outputs, target_ids, reduction='mean')
loss.backward()
# compute the loss
batch_size, seq_len = input_ids.shape
target_ids = target_ids.reshape(-1)
outputs = outputs.view(seq_len*batch_size, -1)
loss = F.cross_entropy(outputs, target_ids, reduction='mean')
loss.backward()
acc_loss += loss.item()
acc_loss += loss.item()
acc_loss /= data_loader.grad_acc
return acc_loss
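For reference (standalone illustration, names not from the repo): the loop above works because PyTorch accumulates gradients in .grad across successive backward() calls until they are explicitly zeroed, which is what lets grad_acc micro-steps share a single optimizer step.
import torch

w = torch.ones(3, requires_grad=True)
grad_acc = 4
for _ in range(grad_acc):
    loss = (w ** 2).sum()
    loss.backward()            # each call adds to w.grad; nothing resets it in between
print(w.grad)                  # 4 * 2 * w = tensor([8., 8., 8.])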
@ -175,6 +179,7 @@ if __name__ == "__main__":
# SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
## hyperparameters
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 16, 4, 3e-4, 100000, int(10e8), 42
grad_acc = 16
assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
@ -197,10 +202,12 @@ if __name__ == "__main__":
dataset_name = "roneneldan/TinyStories"
model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
config = AutoConfig.from_pretrained(model_name)
config.num_attention_heads = 16
config.num_key_value_heads = 4
model = LLaMA(
config=config
).to(device)
)
if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
wandb.init(
@ -220,8 +227,11 @@ if __name__ == "__main__":
},
)
if pgm.process_group_manager.cp_size > 1:
model = ContextParallel(model, config)
if pgm.process_group_manager.tp_world_size > 1:
TensorParallel(model)
# if pgm.process_group_manager.cp_size > 1:
# model = ContextParallel(model, config)
if pgm.process_group_manager.pp_world_size > 1:
model = PipelineParallel(model, config)
@ -229,17 +239,15 @@ if __name__ == "__main__":
if pgm.process_group_manager.dp_world_size > 1:
model = DataParallel(model, pgm.process_group_manager.dp_group)
# if pgm.process_group_manager.tp_world_size > 1:
# model = TensorParallel(model, config)
model.to(device)
model.train()
data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, grad_acc = grad_acc, num_samples=NUM_SAMPLES)
tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
trained_tokens, step = 0, 0
tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN
tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN * grad_acc
dist.barrier()