fix issue with too many ranks reading from the HF library
parent 52a9779345
commit b6267c768e

train.py (26 changed lines)
@@ -100,7 +100,7 @@ if __name__ == "__main__":
     is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage

     set_all_seed(config["training"]["seed"])

     start_time = time.time()
     data_loader = MicroBatchDataLoader(
         micro_batch_size=config["training"]["micro_batch_size"],
@@ -141,11 +141,22 @@ if __name__ == "__main__":
             },
         )

-    model_config = AutoConfig.from_pretrained(config["model"]["name"])
-    model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
-    model_config.num_attention_heads = config["model"]["num_attention_heads"]
-    model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
-    model_config.max_position_embeddings = config["training"]["seq_length"]
+    if pgm.process_group_manager.global_rank == 0:
+        print(f"rank: {pgm.process_group_manager.global_rank}: Creating model config")
+        model_config = AutoConfig.from_pretrained(config["model"]["name"])
+        model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
+        model_config.num_attention_heads = config["model"]["num_attention_heads"]
+        model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
+        model_config.max_position_embeddings = config["training"]["seq_length"]
+        objects = [model_config]
+    else:
+        print(f"rank: {pgm.process_group_manager.global_rank}: Initialized model_config as None")
+        objects = [None]
+
+    dist.broadcast_object_list(objects, src=0, device=device)
+    model_config = objects[0]
+
+    dist.barrier()

     start_time = time.time()

@@ -166,7 +177,8 @@ if __name__ == "__main__":
     model = apply_context_parallel(model)

     model.to(dtype).to(device)

+    # TODO: why not just dp_world_size ?
     if pgm.process_group_manager.cp_dp_world_size > 1:
         model = DataParallelBucket(model)

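The change follows the usual "load on one rank, broadcast to the rest" pattern so that only a single process touches the Hugging Face Hub/cache. A minimal standalone sketch of that pattern is below; the model id and the process-group setup are illustrative placeholders, not taken from this repo:

import torch
import torch.distributed as dist
from transformers import AutoConfig

# Minimal sketch (assumes a torchrun-style launch): only rank 0 reads the
# config from the HF Hub / local cache; every other rank receives it via
# broadcast_object_list instead of hitting the Hub itself.
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

if rank == 0:
    model_config = AutoConfig.from_pretrained("org/model-name")  # placeholder model id
    objects = [model_config]
else:
    objects = [None]

dist.broadcast_object_list(objects, src=0, device=device)  # pickles and broadcasts the object
model_config = objects[0]  # every rank now holds an identical config
dist.barrier()

The trailing dist.barrier() mirrors the commit, where it serves as an explicit sync point before the timing code (start_time = time.time()); the broadcast itself is already a collective in which all ranks must participate.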