From 73772387414e7449fb495e004d046baf4ac4833d Mon Sep 17 00:00:00 2001 From: zzhhjjj Date: Fri, 18 Oct 2024 05:13:44 +0000 Subject: [PATCH] tesnsor parallel, will clean later --- .gitignore | 5 +- model.py | 15 +- setup.py | 6 + src/parallel/context_parallel.py | 2 +- src/parallel/pipeline_parallel.py | 2 +- src/parallel/tensor_parallel.py | 1 - src/parallel/tensor_parallel/__init__.py | 0 src/parallel/tensor_parallel/layers.py | 519 ++++++++++++++++++ src/parallel/tensor_parallel/mappings.py | 96 ++++ .../tensor_parallel/tensor_parallel.py | 51 ++ src/parallel/tensor_parallel/utils.py | 47 ++ train.py | 52 +- 12 files changed, 764 insertions(+), 32 deletions(-) create mode 100644 setup.py delete mode 100644 src/parallel/tensor_parallel.py create mode 100644 src/parallel/tensor_parallel/__init__.py create mode 100644 src/parallel/tensor_parallel/layers.py create mode 100644 src/parallel/tensor_parallel/mappings.py create mode 100644 src/parallel/tensor_parallel/tensor_parallel.py create mode 100644 src/parallel/tensor_parallel/utils.py diff --git a/.gitignore b/.gitignore index 54e505a..35c90d0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ __pycache__ *.pth -.vscode/ \ No newline at end of file +.vscode/ +picotron.egg-info +*.ipynb +wandb \ No newline at end of file diff --git a/model.py b/model.py index a4c28e0..3969cbf 100644 --- a/model.py +++ b/model.py @@ -5,6 +5,8 @@ import torch.nn.functional as F import torch.nn.init as init from flash_attn.flash_attn_interface import flash_attn_func from flash_attn.layers.rotary import apply_rotary_emb +import src.distributed.process_group_manager as pgm +from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding device = 'cuda' if torch.cuda.is_available() else 'cpu' dtype = torch.bfloat16 if os.getenv('DATA_TYPE', 'bfloat16') == 'bfloat16' else torch.float32 @@ -56,13 +58,13 @@ class CausalSelfAttention(nn.Module): self.num_heads = config.num_attention_heads self.num_key_values = config.num_key_value_heads self.head_dim = self.hidden_size//self.num_heads - # model_parallel_size = get_model_parallel_world_size() - model_parallel_size = 1 + model_parallel_size = pgm.process_group_manager.tp_world_size self.num_local_heads = config.num_attention_heads // model_parallel_size # TP parallelism self.num_local_kv_heads = config.num_key_value_heads // model_parallel_size # TP parallelism self.is_merged_qkv_weight = os.getenv('MERGED_QKV_WEIGHT', '1') if self.is_merged_qkv_weight == '1': self.qkv_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim + 2*self.num_key_values*self.head_dim, bias=False) + # self.qkv_proj = ColumnParallelLinear(config.hidden_size, self.num_heads*self.head_dim + 2*self.num_key_values*self.head_dim, bias=False, gather_output=False, init_method=init_method) else: self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) @@ -134,11 +136,12 @@ class LLaMAMLP(nn.Module): self.merged_gate_up = os.getenv('MERGED_GATE_UP_WEIGHT', '1') == '1' if self.merged_gate_up: self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size*2, bias=False) + # self.gate_up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size*2, bias=False, gather_output=False, init_method=init_method) else: - self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) - self.gate_proj = nn.Linear(config.hidden_size, 
config.intermediate_size, bias=False) - # self.up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method) - # self.gate_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method) + # self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + # self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method) + self.gate_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method) self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) # self.down_proj = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=False, input_is_parallel=True, init_method=init_method) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..11dbf3e --- /dev/null +++ b/setup.py @@ -0,0 +1,6 @@ +from setuptools import setup, find_packages +setup( + name="picotron", # Name of the package + version='0.1.0', + packages=find_packages(), # Automatically find packages in the current directory +) \ No newline at end of file diff --git a/src/parallel/context_parallel.py b/src/parallel/context_parallel.py index 06824bd..033cfb2 100644 --- a/src/parallel/context_parallel.py +++ b/src/parallel/context_parallel.py @@ -8,7 +8,7 @@ from src.distributed.distributed_primtives import ContextComms from model import Attention import src.distributed.process_group_manager as pgm -from parallel.base_parallel import BaseParallel +from src.parallel.base_parallel import BaseParallel class ContextParallel(BaseParallel): def __init__(self, model, config): diff --git a/src/parallel/pipeline_parallel.py b/src/parallel/pipeline_parallel.py index 9868016..b99407c 100644 --- a/src/parallel/pipeline_parallel.py +++ b/src/parallel/pipeline_parallel.py @@ -3,7 +3,7 @@ from src.distributed.distributed_primtives import pipeline_communicate, bidirect import torch, torch.nn as nn, torch.nn.functional as F import torch.distributed as dist -from parallel.base_parallel import BaseParallel +from src.parallel.base_parallel import BaseParallel class PipelineParallel(BaseParallel): def __init__(self, model, config): diff --git a/src/parallel/tensor_parallel.py b/src/parallel/tensor_parallel.py deleted file mode 100644 index 503fa1d..0000000 --- a/src/parallel/tensor_parallel.py +++ /dev/null @@ -1 +0,0 @@ -#TODO \ No newline at end of file diff --git a/src/parallel/tensor_parallel/__init__.py b/src/parallel/tensor_parallel/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/parallel/tensor_parallel/layers.py b/src/parallel/tensor_parallel/layers.py new file mode 100644 index 0000000..a892b45 --- /dev/null +++ b/src/parallel/tensor_parallel/layers.py @@ -0,0 +1,519 @@ +""" +Inspired by Fair Scale/Megatron's Tensor Parallelism implementation +Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale +""" +from src.parallel.tensor_parallel.utils import VocabUtility, divide_and_check_no_remainder +import torch +import torch.nn.init as init +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from typing import Callable, Optional +import src.distributed.process_group_manager as pgm +from src.parallel.tensor_parallel.mappings import 
copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region + +# def _initialize_affine_weight( +# weight: torch.Tensor, +# out_features: int, +# in_features: int, +# per_partition_size: int, +# partition_dim: int, +# init_method: Callable[[torch.Tensor], torch.Tensor] +# ) -> Optional[torch.Tensor]: +# """ +# Initialize the master weights for the entire linear layer. Each process will take a partition of the master weight +# Args: +# weight: The weight tensor that will be initialized for the current partition. +# out_features: second dimension of weight matrix W. +# in_features: first dimension of weight matrix W. +# per_partition_size: The size of the weight partition assigned to each process. +# partition_dim: The dimension along which the weight matrix is split for parallelism. +# init_method: The method used to initialize the weight values. +# """ + +# # If we only use 1 process for model parallelism, we can simply initialize the weight +# if pgm.process_group_manager.tp_world_size == 1: +# init_method(weight) +# return None + +# # Initialize master weight +# master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False) +# init_method(master_weight) + +# # Split the model into size of per_partition_size and take the corresponding partition +# weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim) +# weight.data = weight_list[pgm.process_group_manager.tp_rank].contiguous() + +# return None + +# class ColumnParallelLinear(torch.nn.Module): +# """Column Parallel Linear layer +# Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p] +# This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension. +# Arguments: +# in_features: first dimension of weight matrix W. +# out_features: second dimension of weight matrix W. +# bias: If true, add bias +# init_method: method to initialize weights +# gather_output: If true, gather the output from all the partitions. This is used for the last linear layer +# """ + +# def __init__( +# self, +# in_features: int, +# out_features: int, +# bias: bool = False, +# init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, +# gather_output: bool = False, +# ) -> None: +# super(ColumnParallelLinear, self).__init__() + +# self.in_features = in_features +# self.out_features = out_features +# assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size" +# self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size +# self.gather_output = gather_output + +# # Allocate space for the weight and bias +# # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions +# self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i +# if bias: +# self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) +# # Always initialize bias to zero. +# with torch.no_grad(): +# self.bias.zero_() +# else: +# self.register_parameter("bias", None) + +# # Initialize weight. 
+# _initialize_affine_weight( +# self.weight, +# self.out_features, +# self.in_features, +# self.output_size_per_partition, +# partition_dim = 0, +# init_method = init_method, +# ) + +# def forward(self, input_: torch.Tensor) -> torch.Tensor: +# input_parallel = copy_to_model_parallel_region(input_) +# output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i +# if self.gather_output: +# output = gather_from_model_parallel_region(output) +# return output + +# class RowParallelLinear(torch.nn.Module): +# """Linear layer with row parallelism. +# Y = XW + b. W is parallelized along its first dimension and X along its second dimension as: +# - - +# | W_1 | +# | . | +# W = | . | X = [X_1, ..., X_p] +# | . | +# | W_p | +# - - +# We assume that X is already parallelized. This is the case after ColumnParallelLinear. +# This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method. +# Arguments: +# in_features: first dimension of matrix W. +# out_features: second dimension of matrix W. +# bias: If true, add bias +# init_method: method to initialize weights. +# """ + +# def __init__( +# self, +# in_features: int, +# out_features: int, +# bias: bool = True, +# init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, +# ): +# super(RowParallelLinear, self).__init__() + +# # Keep input parameters +# self.in_features = in_features +# self.out_features = out_features +# self.input_size_per_partition = in_features // pgm.process_group_manager.tp_world_size + +# self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition)) +# if bias: +# self.bias = Parameter(torch.Tensor(self.out_features)) +# # Always initialize bias to zero. +# with torch.no_grad(): +# self.bias.zero_() +# else: +# self.register_parameter("bias", None) + +# # Initialize weight. +# _initialize_affine_weight( +# self.weight, +# self.out_features, +# self.in_features, +# self.input_size_per_partition, +# partition_dim = 1, +# init_method = init_method, +# ) + +# def forward(self, input_: torch.Tensor) -> torch.Tensor: +# output_parallel = F.linear(input_, self.weight) # X_i * W_i^T + b +# # All-reduce across all the partitions. +# output_ = reduce_from_model_parallel_region(output_parallel) +# if self.bias is not None: +# output = output_ + self.bias +# else: +# output = output_ +# return output + +# class VocabParallelEmbedding(torch.nn.Module): +# """Embedding parallelized in the vocabulary dimension. +# This is mainly adapted from torch.nn.Embedding and all the default values are kept. +# Arguments: +# num_embeddings: vocabulary size. +# embedding_dim: size of hidden state. +# init_method: method to initialize weights. +# """ + +# def __init__( +# self, +# num_embeddings: int, +# embedding_dim: int, +# padding_idx: Optional[int] = None, +# max_norm: Optional[float] = None, +# norm_type: float = 2.0, +# scale_grad_by_freq: bool = False, +# sparse: bool = False, +# init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, +# ) -> None: +# super(VocabParallelEmbedding, self).__init__() +# # Keep the input dimensions. +# self.num_embeddings = num_embeddings +# self.embedding_dim = embedding_dim +# self.padding_idx = padding_idx +# self.max_norm = max_norm +# self.norm_type = norm_type +# self.scale_grad_by_freq = scale_grad_by_freq +# self.sparse = sparse +# self._weight = None +# # Divide the weight matrix along the vocaburaly dimension. 
+# self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( +# self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size +# ) +# self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index + +# # Allocate weights. +# self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim)) +# # And initialize. +# _initialize_affine_weight( +# self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method +# ) + +# def forward(self, input_: torch.Tensor) -> torch.Tensor: +# """ +# Performs an embedding lookup for input tokens in the parallelized embedding layer +# 1. Masks tokens that fall outside the specified vocabulary range and adjusts the input +# 2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero +# 3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization +# """ +# # Build the mask for out-of-vocabulary tokens. +# input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) +# # Mask the input. +# masked_input = input_.clone() - self.vocab_start_index +# masked_input[input_mask] = 0 +# # Get the embeddings for the valid tokens. +# output_parallel = F.embedding( +# masked_input, +# self.weight, +# self.padding_idx, +# self.max_norm, +# self.norm_type, +# self.scale_grad_by_freq, +# self.sparse, +# ) +# # Embedding of out-of-vocabulary tokens is set to 0. +# output_parallel[input_mask, :] = 0.0 +# # Reduce across all the model parallel GPUs to get the final output. +# output = reduce_from_model_parallel_region(output_parallel) +# return output + + +def _initialize_affine_weight( + weight: torch.Tensor, + out_features: int, + in_features: int, + per_partition_size: int, + partition_dim: int, + init_method: Callable[[torch.Tensor], torch.Tensor], + stride: int = 1, + return_master_weight: bool = False, +) -> Optional[torch.Tensor]: + """Initialize affine weight for model parallel. + + Build the master weight on all processes and scatter + the relevant chunk.""" + + # If we only use 1 process for model parallelism, bypass scatter. + world_size = pgm.process_group_manager.world_size + if world_size == 1: + init_method(weight) + if return_master_weight: + return weight + return None + + # Initialize master weight + master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False) + # init_method(master_weight) + + k = 1.0 / in_features + bound = torch.sqrt(torch.tensor(k, dtype=master_weight.dtype)) + + # Use PyTorch's built-in uniform initialization + init.uniform_(master_weight, -bound.item(), bound.item()) + + # Split and copy + per_partition_per_stride_size = divide_and_check_no_remainder(per_partition_size, stride) + weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim) + rank = pgm.process_group_manager.tp_rank + my_weight_list = weight_list[rank::world_size] + + with torch.no_grad(): + torch.cat(my_weight_list, dim=partition_dim, out=weight) + if return_master_weight: + return master_weight + return None + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + in_features: first dimension of matrix A. + out_features: second dimension of matrix A. 
+ bias: If true, add bias + gather_output: If true, call all-gather on output and make Y avaiable + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + gather_output: bool = True, + init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, + stride: int = 1, + keep_master_weight_for_test: bool = False, + ) -> None: + super(ColumnParallelLinear, self).__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + world_size = pgm.process_group_manager.tp_world_size + self.output_size_per_partition = divide_and_check_no_remainder(out_features, world_size) + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) + if bias: + self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter("bias", None) + + # Initialize weight. + self.master_weight = _initialize_affine_weight( + self.weight, + self.out_features, + self.in_features, + self.output_size_per_partition, + 0, + init_method, + stride=stride, + return_master_weight=keep_master_weight_for_test, + ) + + def get_master_weight(self) -> torch.Tensor: + return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore + # Backprop: all-reduce. + input_parallel = copy_to_model_parallel_region(input_) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight, self.bias) + if self.gather_output: + # All-gather across the partitions. + output = gather_from_model_parallel_region(output_parallel) + else: + output = output_parallel + return output + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + in_features: first dimension of matrix A. + out_features: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + input_is_parallel: bool = True, # Normally, input is parallelized, especially in Attention projection/MLP. There is a column parallel before it. 
+ init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, + stride: int = 1, + keep_master_weight_for_test: bool = False, + ): + super(RowParallelLinear, self).__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.input_is_parallel = input_is_parallel + # Divide the weight matrix along the last dimension. + world_size = pgm.process_group_manager.tp_world_size + self.input_size_per_partition = divide_and_check_no_remainder(in_features, world_size) + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition)) + if bias: + self.bias = Parameter(torch.Tensor(self.out_features)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter("bias", None) + + # Initialize weight. + self.master_weight = _initialize_affine_weight( + self.weight, + self.out_features, + self.in_features, + self.input_size_per_partition, + 1, + init_method, + stride=stride, + return_master_weight=keep_master_weight_for_test, + ) + + def get_master_weight(self) -> torch.Tensor: + return gather_from_model_parallel_region(self.weight.data) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore + # Set up backprop all-reduce. + + input_parallel = input_ + + # Matrix multiply + output_parallel = F.linear(input_parallel, self.weight) + # All-reduce across all the partitions. + output_ = reduce_from_model_parallel_region(output_parallel) + if self.bias is not None: + output = output_ + self.bias + else: + output = output_ + return output + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, + ) -> None: + super(VocabParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + self._weight = None + # Divide the weight matrix along the vocaburaly dimension. + self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size + ) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index + + # Allocate weights. + self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim)) + # And initialize. + _initialize_affine_weight( + self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method + ) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. 
+ masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + # Get the embeddings. + output_parallel = F.embedding( + masked_input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + # Mask the output embedding. + # Embedding of tokens that are not in the vocabulary is set to 0. do a all reduce at the end. + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = reduce_from_model_parallel_region(output_parallel) + return output \ No newline at end of file diff --git a/src/parallel/tensor_parallel/mappings.py b/src/parallel/tensor_parallel/mappings.py new file mode 100644 index 0000000..f0e4aed --- /dev/null +++ b/src/parallel/tensor_parallel/mappings.py @@ -0,0 +1,96 @@ +""" +Inspired by Fair Scale/Megatron's Tensor Parallelism implementation +Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale +""" +from src.parallel.tensor_parallel.utils import split_tensor_along_last_dim +import torch.distributed as dist +import torch +import src.distributed.process_group_manager as pgm + +def _reduce(input_): + """All-reduce the input tensor across model parallel(Tensor Parallel) group.""" + # Bypass the function if we are using only 1 GPU. + if pgm.process_group_manager.tp_size == 1: + return input_ + + # All-reduce across the tensor parallel group + dist.all_reduce(input_, group=pgm.process_group_manager.tp_group) + + return input_ + +class _CopyToModelParallelRegion(torch.autograd.Function): + """copy(identity) in forward pass, all reduce in backward pass""" + @staticmethod + def forward(ctx, input_): + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output) + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """all reduce in forward pass, copy(identity) in backward pass""" + @staticmethod + def forward(ctx, input_): # type: ignore + return _reduce(input_) + + @staticmethod + def backward(ctx, grad_output): # type: ignore + return grad_output + +# This is the `f` function in the paper: https://arxiv.org/abs/1909.08053 +def copy_to_model_parallel_region(input_: torch.Tensor) -> torch.Tensor: + return _CopyToModelParallelRegion.apply(input_) + +# This is the `g` function in the paper, which is the conjugate of `f` +def reduce_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor: + return _ReduceFromModelParallelRegion.apply(input_) + +def _split(input_: torch.Tensor) -> torch.Tensor: + """Split the tensor along its last dimension and keep the corresponding slice.""" + tp_rank = pgm.process_group_manager.tp_rank + tp_world_size = pgm.process_group_manager.tp_size + + # Bypass the function if we are using only 1 GPU + if tp_world_size == 1: + return input_ + + # Split along last dimension and keep the corresponding slice + input_list = split_tensor_along_last_dim(input_, tp_world_size) + output = input_list[tp_rank].contiguous() + + return output + +def _gather(input_: torch.Tensor) -> torch.Tensor: + """Gather tensors and concatinate along the last dimension.""" + tp_rank = pgm.process_group_manager.tp_rank + tp_world_size = pgm.process_group_manager.tp_size + + # Bypass the function if we are using only 1 GPU. + if tp_world_size == 1: + return input_ + + # Size and dimension. 
+ last_dim = input_.dim() - 1 + + tensor_list = [torch.empty_like(input_) for _ in range(tp_world_size)] + tensor_list[tp_rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=pgm.process_group_manager.tp_group) + + output = torch.cat(tensor_list, dim=last_dim).contiguous() + + return output + +class _GatherFromModelParallelRegion(torch.autograd.Function): + """Gathher in the forward pass, split in the backward pass.""" + + @staticmethod + def forward(ctx, input_): + return _gather(input_) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output) + +def gather_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor: + return _GatherFromModelParallelRegion.apply(input_) \ No newline at end of file diff --git a/src/parallel/tensor_parallel/tensor_parallel.py b/src/parallel/tensor_parallel/tensor_parallel.py new file mode 100644 index 0000000..7f14a47 --- /dev/null +++ b/src/parallel/tensor_parallel/tensor_parallel.py @@ -0,0 +1,51 @@ +from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding +import torch.nn.init as init +import torch.nn as nn + +class TensorParallel(nn.Module): + def __init__(self, model, init_method = init.xavier_normal_): + module_linear_name_stype_mapping_list = [ + ("attention", "qkv_proj", "column"), + ("attention", "out_proj", "row"), + ("mlp", "gate_up_proj", "column"), + ("mlp", "down_proj", "row"), + ] + + self.init_method = init_method + + for layer in model.decoder_layers: + for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list: + self.replace_module(getattr(layer, module_name), linear_proj_name, style) + self.replace_module(model, "embedding", "vocab") + self.replace_module(model, "final_proj", "column", args={"gather_output": True}) + + # for name, param in model.named_parameters(): + # print(name, param.shape, param.requires_grad) + + + def replace_module(self,module, linear_proj_name, style, args = {}): + assert style in ["column", "row", 'vocab'] + linear_layer = getattr(module, linear_proj_name) + if style == "column": + new_linear_layer = ColumnParallelLinear( + in_features=linear_layer.in_features, + out_features=linear_layer.out_features, + bias=linear_layer.bias is not None, + init_method=self.init_method, + gather_output=args.get("gather_output", False) + ) + elif style == "row": + new_linear_layer = RowParallelLinear( + in_features=linear_layer.in_features, + out_features=linear_layer.out_features, + bias=linear_layer.bias is not None, + init_method=self.init_method + ) + else: + new_linear_layer = VocabParallelEmbedding( + num_embeddings=linear_layer.num_embeddings, + embedding_dim=linear_layer.embedding_dim, + init_method=self.init_method + ) + setattr(module, linear_proj_name, new_linear_layer) + \ No newline at end of file diff --git a/src/parallel/tensor_parallel/utils.py b/src/parallel/tensor_parallel/utils.py new file mode 100644 index 0000000..f108f75 --- /dev/null +++ b/src/parallel/tensor_parallel/utils.py @@ -0,0 +1,47 @@ +""" +Inspired by Fair Scale/Megatron's Tensor Parallelism implementation +Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale +""" +from typing import Tuple +import torch + +def divide_and_check_no_remainder(numerator: int, denominator: int) -> int: + """Ensure that numerator is divisible by the denominator and return + the division value.""" + assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}" + return numerator // denominator + +def 
split_tensor_along_last_dim(
+    tensor: torch.Tensor, num_partitions: int
+) -> Tuple[torch.Tensor, ...]:
+    """Split a tensor along its last dimension.
+    Arguments:
+        tensor: input tensor.
+        num_partitions: number of partitions to split the tensor into.
+    Returns:
+        A tuple of tensor views; call .contiguous() on a chunk if needed.
+    """
+    # Get the size and dimension.
+    last_dim = tensor.dim() - 1
+    last_dim_size = divide_and_check_no_remainder(tensor.size()[last_dim], num_partitions)
+    # Split.
+    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+    return tensor_list
+
+class VocabUtility:
+    """Split the vocabulary into `world_size` chunks and return the
+    first and last index of the vocabulary belonging to the `rank`
+    partition. Note that indices are in [first, last)."""
+
+    @staticmethod
+    def vocab_range_from_per_partition_vocab_size(
+        per_partition_vocab_size: int, rank: int
+    ) -> Tuple[int, int]:
+        index_f = rank * per_partition_vocab_size
+        index_l = index_f + per_partition_vocab_size
+        return index_f, index_l
+
+    @staticmethod
+    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Tuple[int, int]:
+        per_partition_vocab_size = divide_and_check_no_remainder(global_vocab_size, world_size)
+        return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)
\ No newline at end of file
diff --git a/train.py b/train.py
index dd7f948..a8b5d7d 100644
--- a/train.py
+++ b/train.py
@@ -1,8 +1,9 @@
 """Training script for LLaMA model.
-torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py
+torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --use_wandb
+torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --tp_size 2
 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 2
 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 1 --dp_size 2
-CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=1 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py
+CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --tp_size 2
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
 """
@@ -19,25 +20,26 @@ from datasets import load_dataset
 import argparse
 from datasets import Features, Sequence, Value
 import numpy as np
+from src.parallel.tensor_parallel.tensor_parallel import TensorParallel
 import src.distributed.process_group_manager as pgm
 from utils import set_all_seed, print
 from src.distributed.process_group_manager import setup_process_group_manager
 from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
-from dataclasses import dataclass
 from src.parallel.data_parallel.data_parallel_bucket import DataParallel
-from src.parallel.context_parallel import ContextParallel
+# from src.parallel.context_parallel import ContextParallel
 from model import LLaMA
 import wandb
 import multiprocessing
 
 class MicroBatchDataLoader(DataLoader):
-    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", 
num_samples=None, num_workers=0): + def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, grad_acc = 1, split="train", num_samples=None, num_workers=0): self.global_batch_size = global_batch_size self.micro_batch_size = micro_batch_size self.seq_length = seq_length self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size # each DP rank gets a local batch self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size + self.grad_acc = grad_acc self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size @@ -137,17 +139,19 @@ def train_step(model, data_loader, device): input_ids = batch["input_ids"].to(device) target_ids = batch["target_ids"].to(device) - outputs = model(input_ids=input_ids) + for i in range(data_loader.grad_acc): + outputs = model(input_ids=input_ids) - # compute the loss - batch_size, seq_len = input_ids.shape - target_ids = target_ids.reshape(-1) - outputs = outputs.view(seq_len*batch_size, -1) - loss = F.cross_entropy(outputs, target_ids, reduction='mean') - - loss.backward() + # compute the loss + batch_size, seq_len = input_ids.shape + target_ids = target_ids.reshape(-1) + outputs = outputs.view(seq_len*batch_size, -1) + loss = F.cross_entropy(outputs, target_ids, reduction='mean') + + loss.backward() - acc_loss += loss.item() + acc_loss += loss.item() + acc_loss /= data_loader.grad_acc return acc_loss @@ -175,6 +179,7 @@ if __name__ == "__main__": # SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42 ## hyperparameters SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 16, 4, 3e-4, 100000, int(10e8), 42 + grad_acc = 16 assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism" @@ -197,10 +202,12 @@ if __name__ == "__main__": dataset_name = "roneneldan/TinyStories" model_name = "HuggingFaceTB/SmolLM-360M-Instruct" config = AutoConfig.from_pretrained(model_name) + config.num_attention_heads = 16 + config.num_key_value_heads = 4 model = LLaMA( config=config - ).to(device) + ) if pgm.process_group_manager.global_rank == 0 and args.use_wandb: wandb.init( @@ -220,8 +227,11 @@ if __name__ == "__main__": }, ) - if pgm.process_group_manager.cp_size > 1: - model = ContextParallel(model, config) + if pgm.process_group_manager.tp_world_size > 1: + TensorParallel(model) + + # if pgm.process_group_manager.cp_size > 1: + # model = ContextParallel(model, config) if pgm.process_group_manager.pp_world_size > 1: model = PipelineParallel(model, config) @@ -229,17 +239,15 @@ if __name__ == "__main__": if pgm.process_group_manager.dp_world_size > 1: model = DataParallel(model, pgm.process_group_manager.dp_group) - # if pgm.process_group_manager.tp_world_size > 1: - # model = TensorParallel(model, config) - + model.to(device) model.train() - data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES) + data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, grad_acc = grad_acc, num_samples=NUM_SAMPLES) tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size) optimizer = AdamW(model.parameters(), lr=LEARNING_RATE) trained_tokens, step = 0, 0 - tokens_per_step = 
data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN + tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN * grad_acc dist.barrier()
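The column-then-row split used throughout the new layers rests on a block-matrix identity: shard the first weight along its output (column) dimension and the second along its input (row) dimension, let each rank run both matmuls on its own shards, and a single all-reduce of the partial results reproduces the unsharded output. Below is a minimal single-process sketch of that identity in plain PyTorch (no process groups; the shapes and variable names are illustrative, not taken from the patch), with the all-reduce modelled as a sum:

import torch

torch.manual_seed(0)
tp = 2                                  # pretend tensor-parallel world size
x = torch.randn(4, 8)                   # (batch, hidden)
w1 = torch.randn(16, 8)                 # column-parallel weight, F.linear layout (out, in)
w2 = torch.randn(8, 16)                 # row-parallel weight

# Reference: the unsharded two-matmul projection.
ref = (x @ w1.t()) @ w2.t()

# Shard w1 along its output features (partition_dim=0) and w2 along its input
# features (partition_dim=1), mirroring ColumnParallelLinear / RowParallelLinear.
w1_shards = torch.chunk(w1, tp, dim=0)
w2_shards = torch.chunk(w2, tp, dim=1)

# Each "rank" computes its partial result locally; no communication is needed
# between the two matmuls.
partials = [(x @ w1_i.t()) @ w2_i.t() for w1_i, w2_i in zip(w1_shards, w2_shards)]

# RowParallelLinear's all-reduce is just the sum of the partials.
out = torch.stack(partials).sum(dim=0)
print(torch.allclose(out, ref, atol=1e-5))  # True

This is why the TensorParallel wrapper maps qkv_proj/gate_up_proj to column parallelism with gather_output=False and out_proj/down_proj to row parallelism with input_is_parallel left True: the activation stays sharded between the two projections.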
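VocabParallelEmbedding only ever holds the slice [vocab_start_index, vocab_end_index) of the table, so its forward pass masks foreign token ids, shifts the remaining ids into local coordinates, zeroes the rows it does not own, and relies on the all-reduce to add the per-rank contributions together. A single-process sketch of that masking logic (shard count, sizes and names are illustrative; the all-reduce is again modelled as a sum):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, dim, tp = 12, 5, 2
full_weight = torch.randn(vocab, dim)
tokens = torch.randint(0, vocab, (3, 4))

ref = F.embedding(tokens, full_weight)                      # unsharded lookup

per_rank = vocab // tp
partials = []
for rank in range(tp):
    start, end = rank * per_rank, (rank + 1) * per_rank     # VocabUtility-style range
    shard = full_weight[start:end]                          # this rank's slice of the table
    mask = (tokens < start) | (tokens >= end)               # tokens owned by other ranks
    local_ids = tokens.clone() - start                      # shift into local coordinates
    local_ids[mask] = 0                                     # any in-range index works, it is zeroed below
    out = F.embedding(local_ids, shard)
    out[mask, :] = 0.0                                      # contribute nothing for foreign tokens
    partials.append(out)

out = torch.stack(partials).sum(dim=0)                      # stands in for the all-reduce
print(torch.allclose(out, ref))  # True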
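_initialize_affine_weight materialises the full master weight on every rank, splits it into per-partition chunks along partition_dim, and keeps the chunks at rank::world_size; with stride=1 that collapses to one contiguous block per rank. A small sketch of that slicing with illustrative sizes (2-way split, stride=1):

import torch

torch.manual_seed(0)
world_size, stride = 2, 1
out_features, in_features, partition_dim = 8, 4, 0
per_partition = out_features // world_size           # rows owned by each rank
per_stride = per_partition // stride

master = torch.randn(out_features, in_features)
chunks = torch.split(master, per_stride, dim=partition_dim)

shards = []
for rank in range(world_size):
    mine = chunks[rank::world_size]                   # same indexing as weight_list[rank::world_size]
    shards.append(torch.cat(mine, dim=partition_dim))

# With stride=1 each shard is one contiguous row block, so stacking the shards
# back in rank order reproduces the master weight.
print(torch.equal(torch.cat(shards, dim=partition_dim), master))  # True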
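The TensorParallel wrapper swaps layers in place: it reads the dimensions off the existing projection and rebinds the attribute to the parallel replacement via getattr/setattr. A toy sketch of that swap (the TinyBlock module and the plain nn.Linear stand-in are hypothetical; in the patch the replacement is a ColumnParallelLinear, RowParallelLinear, or VocabParallelEmbedding):

import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv_proj = nn.Linear(16, 48, bias=False)

block = TinyBlock()
old = getattr(block, "qkv_proj")
# Build the replacement from the old layer's shape; an ordinary nn.Linear stands in
# here so the sketch runs without any process group being initialised.
new = nn.Linear(old.in_features, old.out_features, bias=old.bias is not None)
setattr(block, "qkv_proj", new)
print(block.qkv_proj is new)  # True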