From b6267c768e105460684b23acf077ca90be2c8fd8 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Mon, 2 Dec 2024 15:36:47 +0000
Subject: [PATCH] fix issue with too many rank reading to HF library

---
 train.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/train.py b/train.py
index 85344be..0e73242 100644
--- a/train.py
+++ b/train.py
@@ -100,7 +100,7 @@ if __name__ == "__main__":
     is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage
 
     set_all_seed(config["training"]["seed"])
-
+    
     start_time = time.time()
     data_loader = MicroBatchDataLoader(
         micro_batch_size=config["training"]["micro_batch_size"],
@@ -141,11 +141,22 @@ if __name__ == "__main__":
             },
         )
 
-    model_config = AutoConfig.from_pretrained(config["model"]["name"])
-    model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
-    model_config.num_attention_heads = config["model"]["num_attention_heads"]
-    model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
-    model_config.max_position_embeddings = config["training"]["seq_length"]
+    if pgm.process_group_manager.global_rank == 0:
+        print(f"rank: {pgm.process_group_manager.global_rank}: Creating model config")
+        model_config = AutoConfig.from_pretrained(config["model"]["name"])
+        model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
+        model_config.num_attention_heads = config["model"]["num_attention_heads"]
+        model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
+        model_config.max_position_embeddings = config["training"]["seq_length"]
+        objects = [model_config]
+    else:
+        print(f"rank: {pgm.process_group_manager.global_rank}: Initialized model_config as None")
+        objects = [None]
+
+    dist.broadcast_object_list(objects, src=0, device=device)
+    model_config = objects[0]
+
+    dist.barrier()
 
     start_time = time.time()
 
@@ -166,7 +177,8 @@ if __name__ == "__main__":
         model = apply_context_parallel(model)
 
     model.to(dtype).to(device)
-    
+
+    # TODO: why not just dp_world_size ?
     if pgm.process_group_manager.cp_dp_world_size > 1:
         model = DataParallelBucket(model)
 
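
The second hunk above replaces every rank calling AutoConfig.from_pretrained with a "rank 0 loads, everyone else receives" pattern built on torch.distributed.broadcast_object_list, so only one process reads from the Hugging Face Hub (or its cache). Below is a minimal, self-contained sketch of that pattern, separate from the patch itself; the gloo backend, the model name, and the final print are illustrative assumptions rather than picotron's actual configuration.

import torch.distributed as dist
from transformers import AutoConfig

def load_config_on_rank0_and_broadcast(model_name):
    # Only rank 0 talks to the Hugging Face Hub / local cache.
    if dist.get_rank() == 0:
        objects = [AutoConfig.from_pretrained(model_name)]
    else:
        # Other ranks supply a placeholder of the same length and wait for the broadcast.
        objects = [None]
    # broadcast_object_list pickles the object on the src rank and unpickles it everywhere else.
    dist.broadcast_object_list(objects, src=0)
    return objects[0]

if __name__ == "__main__":
    # Assumes a torchrun launch, e.g. `torchrun --nproc_per_node=4 sketch.py`.
    dist.init_process_group(backend="gloo")
    config = load_config_on_rank0_and_broadcast("HuggingFaceTB/SmolLM-135M")
    print(f"rank {dist.get_rank()}: received a {config.model_type} config")
    dist.destroy_process_group()

Launched this way, a single download or cache read happens regardless of world size, which is the behaviour the commit message describes.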