From 83ddda2ce8ee3f8ed2bb5cc35d29d6f2d140b7a6 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Fri, 18 Oct 2024 14:59:39 +0000
Subject: [PATCH] leave out CP integration at the very end

---
 train.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/train.py b/train.py
index a8b5d7d..b025284 100644
--- a/train.py
+++ b/train.py
@@ -8,7 +8,6 @@
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
 """
-import multiprocessing
 import os
 import torch.nn.functional as F
 import torch, torch.distributed as dist
@@ -26,10 +25,9 @@ from utils import set_all_seed, print
 from src.distributed.process_group_manager import setup_process_group_manager
 from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
 from src.parallel.data_parallel.data_parallel_bucket import DataParallel
-# from src.parallel.context_parallel import ContextParallel
-from model import LLaMA
+from src.parallel.context_parallel import ContextParallel
+from model import Llama
 import wandb
-import multiprocessing
 
 class MicroBatchDataLoader(DataLoader):
     def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, grad_acc = 1, split="train", num_samples=None, num_workers=0):
@@ -205,9 +203,7 @@ if __name__ == "__main__":
     config.num_attention_heads = 16
     config.num_key_value_heads = 4
 
-    model = LLaMA(
-        config=config
-    )
+    model = Llama(config=config)
 
     if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
         wandb.init(
@@ -231,13 +227,14 @@ if __name__ == "__main__":
         TensorParallel(model)
 
     # if pgm.process_group_manager.cp_size > 1:
-    #     model = ContextParallel(model, config)
+    #TODO: do at the very end when we have fix convergence issue
+    #     model = ContextParallel(model, config)
 
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, config)
 
     if pgm.process_group_manager.dp_world_size > 1:
-        model = DataParallel(model, pgm.process_group_manager.dp_group)
+        model = DataParallel(model)
 
     model.to(device)
     model.train()
@@ -267,6 +264,7 @@ if __name__ == "__main__":
 
         # average the loss across all DP/CP ranks
         if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
+            #TODO: use all_reduce function from distributed_primitives.py
             loss_tensor = torch.tensor([loss], dtype=torch.float32, device=device)
             handle = dist.all_reduce(loss_tensor, group=pgm.process_group_manager.cp_dp_group, async_op=True, op=dist.ReduceOp.AVG)