lower timeout in train

ferdinand.mom 2024-11-04 14:28:01 +00:00
parent 4e1a6f8cdd
commit 7c381a61eb

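The commit title refers to the `timeout` argument of `torch.distributed.init_process_group`, lowered here from 10 to 3 minutes: it bounds how long collectives (barrier, all_reduce, ...) may block waiting on a peer before the process group errors out, so a hung rank surfaces faster. A minimal self-contained sketch of the same call follows, assuming a single-process gloo setup; the rendezvous values are placeholders for illustration, not taken from this repository:

import datetime
import os

import torch.distributed as dist

# Placeholder rendezvous settings for one local process (assumptions, not
# values from the repo; real runs get these from the launcher).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# `timeout` caps how long collectives may wait before raising an error;
# the commit tightens this from 10 minutes to 3.
dist.init_process_group(
    backend="gloo",
    init_method="env://",
    timeout=datetime.timedelta(minutes=3),
)

dist.barrier()  # with a missing peer this would now fail after ~3 minutes
dist.destroy_process_group()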

@@ -111,25 +111,14 @@ if __name__ == "__main__":
else:
device = torch.device("cpu")
-dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=10))
+dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))
setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE)
is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage
dist.barrier()
set_all_seed(SEED)
model_config = AutoConfig.from_pretrained(MODEL_NAME)
model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
model_config.num_attention_heads = config["model"]["num_attention_heads"]
model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
model_config.max_position_embeddings = SEQ_LEN
start_time = time.time()
model = Llama(config=model_config)
print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
dist.barrier()
start_time = time.time()
data_loader = MicroBatchDataLoader(
micro_batch_size=MICRO_BATCH_SIZE,
@@ -172,6 +161,17 @@ if __name__ == "__main__":
start_time = time.time()
model_config = AutoConfig.from_pretrained(MODEL_NAME)
model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
model_config.num_attention_heads = config["model"]["num_attention_heads"]
model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
model_config.max_position_embeddings = SEQ_LEN
start_time = time.time()
model = Llama(config=model_config)
print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
dist.barrier()
if pgm.process_group_manager.tp_world_size > 1:
TensorParallel(model)
@@ -184,11 +184,10 @@ if __name__ == "__main__":
if pgm.process_group_manager.cp_dp_world_size > 1:
model = DataParallel(model)
-print("init parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
+print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
start_time = time.time()
model.train()
print("model to device time:", time.time()-start_time, is_print_rank=is_wandb_rank)
tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)
@@ -205,7 +204,7 @@ if __name__ == "__main__":
step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
dist.barrier()
-# #TODO: Add activation checkpointing
+#TODO: Add activation checkpointing
while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
#TODO: Add epoch support