small changes unrelated to dp+pp sync grad fix

ferdinand.mom 2024-11-04 15:00:43 +00:00
parent cce11da2cb
commit 0bfc06506a
2 changed files with 25 additions and 9 deletions

@@ -14,6 +14,7 @@ def create_single_config(
out_dir: str,
tp: int,
cp: int,
dp: int,
pp: int,
pp_engine: str,
model_name: str,
@@ -50,6 +51,7 @@ def create_single_config(
config_content['distributed']['tp_size'] = tp
config_content['distributed']['cp_size'] = cp
config_content['distributed']['dp_size'] = dp
config_content['distributed']['pp_size'] = pp
config_content['distributed']['pp_engine'] = pp_engine
@@ -76,6 +78,7 @@ if __name__ == "__main__":
parser.add_argument("--out_dir", type=str, help="Output directory to store the configs", default="tmp")
parser.add_argument("--tp", type=int, help="number of tensor parallelism", default=1)
parser.add_argument("--cp", type=int, help="number of context parallelism", default=1)
parser.add_argument("--dp", type=int, help="number of data parallelism", default=1)
parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")

@@ -13,6 +13,7 @@ import inspect
import datetime
import json
import time
import datetime
import argparse
import torch.nn.functional as F
import torch, torch.distributed as dist
@@ -88,6 +89,8 @@ if __name__ == "__main__":
NUM_PROC = config["dataset"]["num_proc"]
USE_WANDB = config["logging"]["use_wandb"]
TP_SIZE = config["distributed"]["tp_size"]
CP_SIZE = config["distributed"]["cp_size"]
DP_SIZE = config["distributed"]["dp_size"]
PP_SIZE = config["distributed"]["pp_size"]
PP_ENGINE = config["distributed"]["pp_engine"]
LOAD_PATH = config["checkpoint"]["load_path"]
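
The lookups above imply a nested config file with at least these sections; a partial sketch written as a Python dict, using only keys that appear in this diff (all values are placeholders):

config = {
    "dataset": {"num_proc": 8},                # NUM_PROC
    "logging": {"use_wandb": False},           # USE_WANDB
    "distributed": {
        "tp_size": 1,                          # TP_SIZE
        "cp_size": 1,                          # CP_SIZE
        "dp_size": 2,                          # DP_SIZE, now also emitted by create_single_config
        "pp_size": 2,                          # PP_SIZE
        "pp_engine": "afab",                   # PP_ENGINE
    },
    "checkpoint": {"load_path": None},         # LOAD_PATH
    "model": {
        "num_hidden_layers": 12,
        "num_attention_heads": 16,
        "num_key_value_heads": 4,
    },
}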
@@ -127,17 +130,22 @@ if __name__ == "__main__":
model = Llama(config=model_config)
print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
set_all_seed(SEED)
start_time = time.time()
data_loader = MicroBatchDataLoader(
micro_batch_size=MICRO_BATCH_SIZE,
seq_length=SEQ_LEN,
dataset_name=DATASET_NAME,
tokenizer_name=MODEL_NAME,
grad_acc=GRAD_ACC,
grad_acc_steps=GRAD_ACC_STEPS,
num_workers=NUM_WORKERS,
num_proc=NUM_PROC,
num_samples=NUM_SAMPLES
)
dist.barrier()
print("init dataloader time:", time.time()-start_time, is_print_rank=is_wandb_rank)
tokens_per_step = data_loader.global_batch_size * SEQ_LEN
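
tokens_per_step follows from the usual batch-size bookkeeping; a small worked sketch, assuming (this is not shown in the diff) that MicroBatchDataLoader derives its global batch size from the micro-batch size, the gradient-accumulation steps, and the data-parallel size:

# Assumed relation; the exact computation lives inside MicroBatchDataLoader.
micro_batch_size = 2
grad_acc_steps = 4
dp_size = 2
seq_len = 1024

global_batch_size = micro_batch_size * grad_acc_steps * dp_size  # 16 sequences per optimizer step
tokens_per_step = global_batch_size * seq_len                    # 16384 tokens per optimizer step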
@@ -166,6 +174,17 @@ if __name__ == "__main__":
start_time = time.time()
model_config = AutoConfig.from_pretrained(MODEL_NAME)
model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
model_config.num_attention_heads = config["model"]["num_attention_heads"]
model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
model_config.max_position_embeddings = SEQ_LEN
start_time = time.time()
model = Llama(config=model_config)
print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
dist.barrier()
if pgm.process_group_manager.tp_world_size > 1:
TensorParallel(model)
@@ -178,11 +197,10 @@ if __name__ == "__main__":
if pgm.process_group_manager.cp_dp_world_size > 1:
model = DataParallel(model)
print("init parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
start_time = time.time()
model.train()
print("model to device time:", time.time()-start_time, is_print_rank=is_wandb_rank)
tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)
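
tensor_shapes sizes the activation tensors that adjacent pipeline stages exchange for each micro-batch. A minimal sketch of such an exchange using torch.distributed point-to-point calls; the ranks, shapes, and run_stage_layers placeholder are illustrative rather than picotron's actual helpers, and it assumes seq_length_per_gpu is the per-rank slice of the sequence:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called by the launcher.
tensor_shapes = (2, 1024, 2048)  # (micro_batch_size, seq_length_per_gpu, hidden_size), placeholder values
prev_rank, next_rank = 0, 2      # illustrative neighbouring pipeline-stage ranks
run_stage_layers = lambda x: x   # placeholder for this stage's transformer blocks

recv_buffer = torch.empty(tensor_shapes, dtype=torch.float32, device="cuda")
dist.recv(recv_buffer, src=prev_rank)   # hidden states produced by the previous stage
hidden = run_stage_layers(recv_buffer)  # this stage's forward pass
dist.send(hidden, dst=next_rank)        # hand the activations to the next stage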
@@ -199,12 +217,7 @@ if __name__ == "__main__":
step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
dist.barrier()
# #TODO: Double-check consumed tokens after each steps (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0)
# #TODO: Check convergence
# #TODO: Try multi-nodes
# #TODO: Add activation checkpointing
# #TODO: add gradient accumulation
#TODO: Add activation checkpointing
while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
#TODO: Add epoch support