diff --git a/train.py b/train.py
index 2c27873..821e6c6 100644
--- a/train.py
+++ b/train.py
@@ -10,6 +10,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --
 """
 import os
 import inspect
+import datetime
 import json
 import time
 import argparse
@@ -94,9 +95,9 @@ if __name__ == "__main__":
     CHECKPOINT_FREQ = config["checkpoint"]["save_frequency"]
 
     local_rank = int(os.environ["LOCAL_RANK"])
+    global_rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
-    host = os.environ["MASTER_ADDR"]
-    port = int(os.environ["MASTER_PORT"])
+
     backend = "gloo" if config["distributed"]["use_cpu"] else "nccl"
 
     assert SEQ_LEN % CP_SIZE == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
@@ -108,10 +109,12 @@ if __name__ == "__main__":
     else:
         device = torch.device("cpu")
 
-    dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")
+    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))
 
     setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE)
     is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage
 
+    dist.barrier()
+    set_all_seed(SEED)
 
     model_config = AutoConfig.from_pretrained(MODEL_NAME)
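
For context, here is a minimal, standalone sketch of the `env://` rendezvous pattern the new code relies on. It is an illustration under stated assumptions, not the project's actual setup code: `torchrun` exports `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT` for every worker, so `init_process_group` can read the rendezvous information from the environment instead of a hand-built `tcp://host:port` URL. The backend selection and the final prints are assumptions of the sketch; the config handling and process-group manager from the real train.py are omitted.

```python
# Minimal env:// initialization sketch (assumes launch via torchrun).
import os
import datetime

import torch
import torch.distributed as dist

# torchrun sets these environment variables for every worker process.
global_rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Assumed backend choice for the sketch: NCCL on GPU, Gloo on CPU.
backend = "nccl" if torch.cuda.is_available() else "gloo"
if backend == "nccl":
    torch.cuda.set_device(local_rank)

# env:// reads MASTER_ADDR / MASTER_PORT from the environment, so no tcp://
# URL has to be assembled by hand. The timeout turns silent hangs into errors.
dist.init_process_group(
    backend=backend,
    init_method="env://",
    rank=global_rank,
    world_size=world_size,
    timeout=datetime.timedelta(minutes=3),
)

dist.barrier()  # wait until every rank has finished initialization
print(f"rank {global_rank}/{world_size} ready on backend {backend}")
dist.destroy_process_group()
```

Launched with, for example, `torchrun --nproc_per_node=4 --nnodes=1 sketch.py`, each of the four processes reaches the barrier before continuing, mirroring the `dist.barrier()` added above before seeding and model construction.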