#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
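"""
Multi-dimensional parallel training of SmolLM-360M-Instruct on TinyStories.

The script composes tensor (TP), context (CP), pipeline (PP) and data (DP)
parallelism. Launch it with torchrun as in the command above; the product of
--tp_size, --cp_size, --pp_size and --dp_size is expected to match the number
of launched processes.
"""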
import os
import argparse

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoConfig, AutoModelForCausalLM
import wandb

import process_group_manager as pgm
from process_group_manager import setup_process_group_manager
from utils import set_all_seed, display_parallelism_grid, print
from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from data_parallel import DataParallel
from context_parallel import ContextParallel
from model import Llama
from dataset import MicroBatchDataLoader

def train_step(model, data_loader, device):
    total_loss = 0.0
    # Build the iterator once so successive micro-batches are consumed in order
    # (calling iter(data_loader) inside the loop would restart from the first batch).
    data_loader_iterator = iter(data_loader)

    for _ in range(data_loader.num_local_micro_batches):
        batch = next(data_loader_iterator)

        input_ids = batch["input_ids"].to(device)
        position_ids = batch["position_index"].to(device)
        target_ids = batch["target_ids"].to(device)

        outputs = model(input_ids=input_ids, position_ids=position_ids)
        logits = outputs.logits

        # Cross-entropy over the vocabulary dimension: transpose logits to
        # (batch, vocab_size, seq_len) as expected by F.cross_entropy.
        loss = F.cross_entropy(logits.transpose(1, 2), target_ids, reduction='mean')
        loss.backward()

        total_loss += loss.item()

    avg_loss = total_loss / data_loader.num_local_micro_batches
    return avg_loss


def all_reduce_grads_across_dp_cp_ranks(model):
    # Average the gradients across all DP & CP ranks so every replica applies
    # the same optimizer update.
    for param in model.parameters():
        if param.grad is not None:
            param.grad /= pgm.process_group_manager.cp_dp_world_size
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tp_size", type=int, default=1)
    parser.add_argument("--cp_size", type=int, default=1)
    parser.add_argument("--pp_size", type=int, default=1)
    parser.add_argument("--dp_size", type=int, default=1)
    parser.add_argument("--use_wandb", action="store_true", default=False)
    parser.add_argument("--use_cpu", action="store_true", default=False)
    parser.add_argument("--master_addr", type=str, default="localhost")
    parser.add_argument("--master_port", type=int, default=29500)
    args = parser.parse_args()

    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    host = os.environ["MASTER_ADDR"]
    port = int(os.environ["MASTER_PORT"])

    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42

    backend = "gloo" if args.use_cpu else "nccl"
    if backend == "nccl":
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cpu")

    # torchrun exports RANK (global) and LOCAL_RANK (per node); init_process_group
    # needs the global rank so that multi-node launches also work.
    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")

    setup_process_group_manager(tp_size=args.tp_size, cp_size=args.cp_size, pp_size=args.pp_size, dp_size=args.dp_size)

    if pgm.process_group_manager.global_rank == 0:
        display_parallelism_grid()

    set_all_seed(SEED)

    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
    dataset_name = "roneneldan/TinyStories"
    config = AutoConfig.from_pretrained(model_name)

    if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
        wandb.init(
            project="picotron",
            name=f"test_convergence_{pgm.process_group_manager}",
            config={
                "tensor_parallel_size": pgm.process_group_manager.tp_size,
                "pipeline_parallel_size": pgm.process_group_manager.pp_size,
                "data_parallel_size": pgm.process_group_manager.dp_size,
                "model": model_name,
                "dataset": dataset_name,
                "max_tokens": MAX_TOKENS,
                "learning_rate": LEARNING_RATE,
                "seed": SEED,
                "micro_batch_size": MICRO_BATCH_SIZE,
                "global_batch_size": GLOBAL_BATCH_SIZE,
            },
        )

    # TODO: find a better way (should only need to specify model_name + path to .pth)
    model = Llama(
        config=config,
        device=device,
    ).to(device)
    model.load_state_dict(torch.load("smollm.pth", map_location=device))
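
    # Compose the parallelism wrappers around the base model: context parallelism,
    # then pipeline parallelism, then data parallelism. Each wrapper is applied only
    # when its process group spans more than one rank.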
    if pgm.process_group_manager.cp_size > 1:
        model = ContextParallel(model, config).to(device)
    if pgm.process_group_manager.pp_world_size > 1:
        model = PipelineParallel(model, config).to(device)
    if pgm.process_group_manager.dp_world_size > 1:
        model = DataParallel(model, config).to(device)

    model.train()

    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
    # Shape of the activations exchanged between pipeline stages: (seq_len, micro_batch_size, hidden_size)
    tensor_shapes = (SEQ_LEN, data_loader.micro_batch_size, config.hidden_size)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    trained_tokens, step = 0, 0
    tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN

    dist.barrier()

    # TODO: Add context parallelism
    # TODO: Double-check consumed tokens after each step (for example, with MICRO_BATCH_SIZE=2 and only dp_size=4, num_local_micro_batches=0 => division by zero)
    # TODO: Check convergence
    # TODO: Try multi-node
    # TODO: Add activation checkpointing
    # TODO: Add gradient accumulation
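
    # Training loop: with pipeline parallelism, use the AFAB (all-forward-all-backward)
    # schedule; otherwise run the plain micro-batch loop. Gradients are then averaged
    # across DP/CP ranks before the optimizer step.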
    while trained_tokens < MAX_TOKENS:
        data_loader.set_epoch(step)
        optimizer.zero_grad()

        if pgm.process_group_manager.pp_world_size > 1:
            loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device)
        else:
            loss = train_step(model, data_loader, device)

        if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
            all_reduce_grads_across_dp_cp_ranks(model)

        optimizer.step()
        trained_tokens += tokens_per_step
        step += 1

        if pgm.process_group_manager.global_rank == 0:
            print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")

        if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
            wandb.log({"loss": loss, "trained_tokens": trained_tokens})

    if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
        wandb.finish()

    dist.destroy_process_group()