diff --git a/create_config.py b/create_config.py
index e3f1334..ad55038 100644
--- a/create_config.py
+++ b/create_config.py
@@ -14,6 +14,7 @@ def create_single_config(
     out_dir: str,
     tp: int,
     cp: int,
+    dp: int,
     pp: int,
     pp_engine: str,
     model_name: str,
@@ -50,6 +51,7 @@ def create_single_config(
 
     config_content['distributed']['tp_size'] = tp
     config_content['distributed']['cp_size'] = cp
+    config_content['distributed']['dp_size'] = dp
     config_content['distributed']['pp_size'] = pp
     config_content['distributed']['pp_engine'] = pp_engine
 
@@ -76,6 +78,7 @@ if __name__ == "__main__":
     parser.add_argument("--out_dir", type=str, help="Output directory to store the configs", default="tmp")
     parser.add_argument("--tp", type=int, help="number of tensor parallelism", default=1)
     parser.add_argument("--cp", type=int, help="number of context parallelism", default=1)
+    parser.add_argument("--dp", type=int, help="number of data parallelism", default=1)
     parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
     parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
     parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
diff --git a/train.py b/train.py
index 821e6c6..1137ef5 100644
--- a/train.py
+++ b/train.py
@@ -13,6 +13,7 @@ import inspect
 import datetime
 import json
 import time
+import datetime
 import argparse
 import torch.nn.functional as F
 import torch, torch.distributed as dist
@@ -88,6 +89,8 @@ if __name__ == "__main__":
     NUM_PROC = config["dataset"]["num_proc"]
     USE_WANDB = config["logging"]["use_wandb"]
     TP_SIZE = config["distributed"]["tp_size"]
+    CP_SIZE = config["distributed"]["cp_size"]
+    DP_SIZE = config["distributed"]["dp_size"]
     PP_SIZE = config["distributed"]["pp_size"]
     PP_ENGINE = config["distributed"]["pp_engine"]
     LOAD_PATH = config["checkpoint"]["load_path"]
@@ -127,17 +130,22 @@ if __name__ == "__main__":
     model = Llama(config=model_config)
     print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
 
+    set_all_seed(SEED)
+
     start_time = time.time()
     data_loader = MicroBatchDataLoader(
         micro_batch_size=MICRO_BATCH_SIZE,
         seq_length=SEQ_LEN,
         dataset_name=DATASET_NAME,
         tokenizer_name=MODEL_NAME,
-        grad_acc=GRAD_ACC,
+        grad_acc_steps=GRAD_ACC_STEPS,
         num_workers=NUM_WORKERS,
         num_proc=NUM_PROC,
         num_samples=NUM_SAMPLES
     )
+
+    dist.barrier()
+
     print("init dataloader time:", time.time()-start_time, is_print_rank=is_wandb_rank)
 
     tokens_per_step = data_loader.global_batch_size * SEQ_LEN
@@ -166,6 +174,17 @@ if __name__ == "__main__":
 
     start_time = time.time()
 
+    model_config = AutoConfig.from_pretrained(MODEL_NAME)
+    model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
+    model_config.num_attention_heads = config["model"]["num_attention_heads"]
+    model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
+    model_config.max_position_embeddings = SEQ_LEN
+
+    start_time = time.time()
+    model = Llama(config=model_config)
+    print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
+    dist.barrier()
+
     if pgm.process_group_manager.tp_world_size > 1:
         TensorParallel(model)
 
@@ -178,11 +197,10 @@ if __name__ == "__main__":
     if pgm.process_group_manager.cp_dp_world_size > 1:
         model = DataParallel(model)
 
-    print("init parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
+    print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
 
     start_time = time.time()
     model.train()
-    
print("model to device time:", time.time()-start_time, is_print_rank=is_wandb_rank) tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size) @@ -199,12 +217,7 @@ if __name__ == "__main__": step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH) dist.barrier() - - # #TODO: Double-check consumed tokens after each steps (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0) - # #TODO: Check convergence - # #TODO: Try multi-nodes - # #TODO: Add activation checkpointing - # #TODO: add gradient accumulation + #TODO: Add activation checkpointing while MAX_TOKENS is None or trained_tokens < MAX_TOKENS: #TODO: Add epoch support