"""Training script for the LLaMA model.

CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/llama2_7b_benchmark.json

CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/llama2_7b_benchmark.json
"""
import os
import inspect
import datetime
import json
import time
import argparse

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim import AdamW
from transformers import AutoConfig

from picotron.context_parallel.context_parallel import apply_context_parallel
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
import picotron.process_group_manager as pgm
from picotron.utils import average_loss_across_dp_cp_ranks, set_all_seed, print, to_readable_format, get_mfu, get_num_params
from picotron.checkpoint import CheckpointManager
from picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights
from picotron.data import MicroBatchDataLoader
from picotron.process_group_manager import setup_process_group_manager
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from picotron.data_parallel.data_parallel import DataParallelBucket
from picotron.model import Llama

import wandb

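# Non-pipeline training step: iterate over grad_acc_steps micro-batches,
# accumulating gradients and deferring the data-parallel all-reduce to the last one.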
def train_step(model, data_loader, device):
    acc_loss = 0.0

    requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1

    for i in range(data_loader.grad_acc_steps):
        # get the next batch
        batch = next(data_loader)
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["target_ids"].to(device)

        # disable gradient synchronization for all but the last micro-batch
        if requires_grad_sync:
            model.require_backward_grad_sync = (i == data_loader.grad_acc_steps - 1)

        outputs = model(input_ids=input_ids)

        # compute the loss
        batch_size, seq_len = input_ids.shape
        target_ids = target_ids.reshape(-1)
        outputs = outputs.view(seq_len * batch_size, -1)
        loss = F.cross_entropy(outputs, target_ids, reduction='mean') / data_loader.grad_acc_steps

        loss.backward()

        acc_loss += loss.item()

    return acc_loss

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="", help="Path to config file")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = json.load(f)

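    # export environment settings from the config so that downstream libraries
    # (tokenizers, attention kernels, device selection) pick them up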
    os.environ["OMP_NUM_THREADS"] = config["environment"]["OMP_NUM_THREADS"]
    os.environ["TOKENIZERS_PARALLELISM"] = config["environment"]["TOKENIZERS_PARALLELISM"]
    os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"]
    os.environ["DEVICE"] = "cpu" if config["distributed"]["use_cpu"] else "cuda"

    if config["environment"]["HF_TOKEN"] is None:
        raise ValueError("HF_TOKEN is not set in the config file")
    os.environ["HF_TOKEN"] = config["environment"]["HF_TOKEN"]

    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32
    assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations require dtype=torch.bfloat16"

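    # torchrun / torch.distributed.run populate these variables for every worker process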
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    backend = "gloo" if config["distributed"]["use_cpu"] else "nccl"

    assert config["training"]["seq_length"] % config["distributed"]["cp_size"] == 0, "seq_length must be divisible by cp_size for Context Parallelism"
    assert world_size == config["distributed"]["tp_size"] * config["distributed"]["pp_size"] * config["distributed"]["dp_size"] * config["distributed"]["cp_size"], "world_size must be equal to tp_size * pp_size * dp_size * cp_size"

    if backend == "nccl":
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cpu")

    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method="env://", timeout=datetime.timedelta(minutes=3))

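    # build the process-group grid for tensor, context, pipeline and data parallelism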
    setup_process_group_manager(
        tp_size=config["distributed"]["tp_size"],
        cp_size=config["distributed"]["cp_size"],
        pp_size=config["distributed"]["pp_size"],
        dp_size=config["distributed"]["dp_size"]
    )

    is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage

    set_all_seed(config["training"]["seed"])

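    # build the micro-batch dataloader and time how long its setup takes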
    start_time = time.time()
    data_loader = MicroBatchDataLoader(
        micro_batch_size=config["training"]["micro_batch_size"],
        seq_length=config["training"]["seq_length"],
        dataset_name=config["dataset"]["name"],
        tokenizer_name=config["model"]["name"],
        grad_acc_steps=config["training"]["gradient_accumulation_steps"],
        device=device,
        num_workers=config["dataset"]["num_workers"],
        num_proc=config["dataset"]["num_proc"],
        num_samples=config["training"].get("num_samples", None),
        subset_name=config["dataset"].get("subset_name", None),
        split=config["dataset"].get("split", "train")
    )

    dist.barrier()

    print(f"init dataloader time: {time.time() - start_time:.2f}s", is_print_rank=is_wandb_rank)

    tokens_per_step = data_loader.global_batch_size * config["training"]["seq_length"]

    if pgm.process_group_manager.global_rank == 0:
        print("Tokens per step:", to_readable_format(tokens_per_step), is_print_rank=is_wandb_rank)

    if is_wandb_rank and config["logging"]["use_wandb"]:
        wandb.init(
            project="picotron",
            name=f"{config['logging']['run_name']}_{to_readable_format(tokens_per_step)}_{pgm.process_group_manager}",
            config={
                "tensor_parallel_size": pgm.process_group_manager.tp_world_size,
                "context_parallel_size": pgm.process_group_manager.cp_world_size,
                "pipeline_parallel_size": pgm.process_group_manager.pp_world_size,
                "data_parallel_size": pgm.process_group_manager.dp_world_size,
                "model": config["model"]["name"],
                "dataset": config["dataset"]["name"],
                "max_tokens": config["training"]["max_tokens"],
                "learning_rate": config["training"]["learning_rate"],
                "seed": config["training"]["seed"],
                "micro_batch_size": data_loader.micro_batch_size,
                "global_batch_size": data_loader.global_batch_size,
                "gradient_accumulation": data_loader.grad_acc_steps,
            },
        )

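    # rank 0 loads the Hugging Face model config, optionally overrides its shape,
    # then broadcasts it so every rank builds an identical model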
    if pgm.process_group_manager.global_rank == 0:
        print(f"rank {pgm.process_group_manager.global_rank}: Creating model config")
        model_config = AutoConfig.from_pretrained(config["model"]["name"])
        # override the model structure if specified in the config file
        model_config.num_hidden_layers = config["model"].get("num_hidden_layers", model_config.num_hidden_layers)
        model_config.num_attention_heads = config["model"].get("num_attention_heads", model_config.num_attention_heads)
        model_config.num_key_value_heads = config["model"].get("num_key_value_heads", model_config.num_key_value_heads)
        model_config.max_position_embeddings = config["training"]["seq_length"]
        objects = [model_config]
    else:
        objects = [None]
    dist.broadcast_object_list(objects, src=0, device=device)
    model_config = objects[0]

    print(f"rank {pgm.process_group_manager.global_rank}: Broadcasting model_config to all ranks", is_print_rank=pgm.process_group_manager.global_rank == 0)
    dist.barrier()

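    # build the model on meta (dematerialized) weights first, apply tensor/pipeline
    # parallelism, then materialize the real weights for this rank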
    print(f"rank {pgm.process_group_manager.global_rank}: Initializing model meta device", is_print_rank=is_wandb_rank)
    start_time = time.time()

    with init_model_with_dematerialized_weights():
        model = Llama(config=model_config)

        if pgm.process_group_manager.tp_world_size > 1:
            model = apply_tensor_parallel(model)

        if pgm.process_group_manager.pp_world_size > 1:
            model = PipelineParallel(model, model_config)

    model = init_model_with_materialized_weights(model, model_config, save_dir=f"./hf_model_safetensors/{model_config._name_or_path}")

    # TODO: load existing checkpoint here to continue pre-training

    if pgm.process_group_manager.cp_world_size > 1:
        model = apply_context_parallel(model)

    model.to(dtype).to(device)

    if pgm.process_group_manager.dp_world_size > 1:
        model = DataParallelBucket(model)

    print(f"init model parallel time: {time.time() - start_time:.2f}s", is_print_rank=is_wandb_rank)

    model.train()

    num_params = get_num_params(model)
    print(f"Number of parameters: {to_readable_format(num_params)}", is_print_rank=is_wandb_rank)

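    # activation tensor shapes exchanged between pipeline stages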
    tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)

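    # use the fused AdamW kernel when this PyTorch build exposes it and we run on GPU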
    extra_args = dict()
    if config["model"]["use_fused_adam"]:
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device.type == 'cuda'  # compare the device type, not the device object
        extra_args = dict(fused=True) if use_fused else dict()

    optimizer = AdamW(model.parameters(), lr=config["training"]["learning_rate"], **extra_args)

    checkpoint_manager = CheckpointManager()

    trained_tokens, step = 0, 0
    if config["checkpoint"]["load_path"]:
        step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, config["checkpoint"]["load_path"])

    dist.barrier()

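    # training loop: run until the token budget (max_tokens) or the step budget (total_train_steps) is reached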
    while config["training"]["max_tokens"] is None or trained_tokens < config["training"]["max_tokens"]:
        step_start_time = time.time()
        optimizer.zero_grad()

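        # with pipeline parallelism, dispatch to the configured schedule (AFAB or 1F1B);
        # otherwise run the plain gradient-accumulation step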
        if pgm.process_group_manager.pp_world_size > 1:
            if config["distributed"]["pp_engine"] == "afab":
                loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype)
            elif config["distributed"]["pp_engine"] == "1f1b":
                loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype)
            else:
                raise ValueError(f"Invalid pipeline parallel engine: {config['distributed']['pp_engine']}")
        else:
            loss = train_step(model, data_loader, device)

        loss = average_loss_across_dp_cp_ranks(loss, device)

        optimizer.step()
        trained_tokens += tokens_per_step
        step += 1

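        # let wrappers that define a reset() hook (e.g. DataParallelBucket) clear their per-step state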
        if hasattr(model, 'reset'):
            model.reset()

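        # per-step throughput and model FLOPs utilization (MFU)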
        step_duration = time.time() - step_start_time
        tokens_per_second = tokens_per_step / step_duration
        tokens_per_second_per_gpu = tokens_per_second / world_size
        mfu = get_mfu(tokens_per_second_per_gpu, num_params, model_config)

        if is_wandb_rank:
            print(
                f"[rank {pgm.process_group_manager.global_rank}] "
                f"Step: {step:<5d} | "
                f"Loss: {loss:6.4f} | "
                f"Global batch size: {to_readable_format(tokens_per_step):>7s} | "
                f"Tokens/s: {to_readable_format(tokens_per_second):>7s} | "
                f"Tokens/s/GPU: {to_readable_format(tokens_per_second_per_gpu):>7s} | "
                f"Tokens: {to_readable_format(trained_tokens):>7s}{('/' + to_readable_format(config['training']['max_tokens'])) if config['training']['max_tokens'] else ''} | "
                f"MFU: {mfu:5.2f}% | "
                f"Memory usage: {torch.cuda.memory_reserved() / 1e9:6.2f}GB",
                is_print_rank=is_wandb_rank
            )

            if config["logging"]["use_wandb"]:
                wandb.log({
                    "loss": loss,
                    "tokens_per_step": tokens_per_step,
                    "tokens_per_second": tokens_per_step / step_duration,
                    "mfu": mfu,
                    "tokens_per_second_per_gpu": tokens_per_second_per_gpu,
                    "memory_usage": torch.cuda.memory_reserved() / 1e9,
                    "trained_tokens": trained_tokens
                })

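        # periodically checkpoint the model and optimizer state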
        if step % config["checkpoint"]["save_frequency"] == 0:
            checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, config["checkpoint"]["save_dir"] + f"/{step}")

        if step >= config["training"]["total_train_steps"]:
            break

    if is_wandb_rank and config["logging"]["use_wandb"]:
        wandb.finish()

    dist.destroy_process_group()