lower timeout in train
parent 4e1a6f8cdd
commit 7c381a61eb

train.py (31 changed lines)
@@ -111,25 +111,14 @@ if __name__ == "__main__":
     else:
         device = torch.device("cpu")
 
-    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=10))
+    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))
     setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE)
     is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage
 
     dist.barrier()
 
     set_all_seed(SEED)
 
-    model_config = AutoConfig.from_pretrained(MODEL_NAME)
-    model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
-    model_config.num_attention_heads = config["model"]["num_attention_heads"]
-    model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
-    model_config.max_position_embeddings = SEQ_LEN
-
-    start_time = time.time()
-    model = Llama(config=model_config)
-    print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
-    dist.barrier()
-
     start_time = time.time()
     data_loader = MicroBatchDataLoader(
         micro_batch_size=MICRO_BATCH_SIZE,
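Note: the only functional change in this hunk is the shorter collective-init window passed to dist.init_process_group; the deleted model-config block reappears later in the file (next hunk). As a rough, standalone illustration of that knob (not code from this repo; the env-var handling and backend choice below are assumptions), a rank that never reaches the rendezvous now surfaces as an error after the timeout instead of blocking indefinitely:

# Minimal sketch, assuming torchrun-style RANK/WORLD_SIZE env vars; not taken
# from train.py. It isolates the setting this commit lowers from 10 to 3 minutes.
import datetime
import os

import torch
import torch.distributed as dist

def init_distributed(timeout_minutes: int = 3) -> None:
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(
        backend=backend,
        init_method="env://",  # rank and world size come from the environment
        rank=int(os.environ["RANK"]),
        world_size=int(os.environ["WORLD_SIZE"]),
        # A missing or stuck rank now fails after 3 minutes rather than 10.
        timeout=datetime.timedelta(minutes=timeout_minutes),
    )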
@@ -172,6 +161,17 @@ if __name__ == "__main__":
 
     start_time = time.time()
 
+    model_config = AutoConfig.from_pretrained(MODEL_NAME)
+    model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
+    model_config.num_attention_heads = config["model"]["num_attention_heads"]
+    model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
+    model_config.max_position_embeddings = SEQ_LEN
+
+    start_time = time.time()
+    model = Llama(config=model_config)
+    print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
+    dist.barrier()
+
     if pgm.process_group_manager.tp_world_size > 1:
         TensorParallel(model)
 
@@ -184,11 +184,10 @@ if __name__ == "__main__":
     if pgm.process_group_manager.cp_dp_world_size > 1:
         model = DataParallel(model)
 
-    print("init parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
-    start_time = time.time()
+    print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
 
     model.train()
     print("model to device time:", time.time()-start_time, is_print_rank=is_wandb_rank)
 
     tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)
 
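The renamed print above belongs to the series of start_time / print pairs that time each setup phase and log only from the wandb rank. A small, hypothetical helper (not part of this repository; the name timed and its arguments are invented here) expressing the same pattern in isolation:

# Minimal sketch of the timing pattern used in the diff; nothing below is
# taken from train.py.
import time
from contextlib import contextmanager

@contextmanager
def timed(label: str, should_print: bool):
    start = time.time()
    try:
        yield
    finally:
        # Mirrors print("... time:", time.time()-start_time, is_print_rank=...)
        if should_print:
            print(f"{label}: {time.time() - start:.2f}s")

# Hypothetical usage mirroring one of the phases timed in this commit:
# with timed("init model parallel time", is_wandb_rank):
#     if pgm.process_group_manager.tp_world_size > 1:
#         TensorParallel(model)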
@@ -205,7 +204,7 @@ if __name__ == "__main__":
     step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
 
     dist.barrier()
-    # #TODO: Add activation checkpointing
+    #TODO: Add activation checkpointing
 
     while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
        #TODO: Add epoch support