fix issue with too many rank reading to HF library

This commit is contained in:
ferdinand.mom 2024-12-02 15:36:47 +00:00
parent 52a9779345
commit b6267c768e

View File

@ -141,11 +141,22 @@ if __name__ == "__main__":
}, },
) )
model_config = AutoConfig.from_pretrained(config["model"]["name"]) if pgm.process_group_manager.global_rank == 0:
model_config.num_hidden_layers = config["model"]["num_hidden_layers"] print(f"rank: {pgm.process_group_manager.global_rank}: Creating model config")
model_config.num_attention_heads = config["model"]["num_attention_heads"] model_config = AutoConfig.from_pretrained(config["model"]["name"])
model_config.num_key_value_heads = config["model"]["num_key_value_heads"] model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
model_config.max_position_embeddings = config["training"]["seq_length"] model_config.num_attention_heads = config["model"]["num_attention_heads"]
model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
model_config.max_position_embeddings = config["training"]["seq_length"]
objects = [model_config]
else:
print(f"rank: {pgm.process_group_manager.global_rank}: Initialized model_config as None")
objects = [None]
dist.broadcast_object_list(objects, src=0, device=device)
model_config = objects[0]
dist.barrier()
start_time = time.time() start_time = time.time()
@ -167,6 +178,7 @@ if __name__ == "__main__":
model.to(dtype).to(device) model.to(dtype).to(device)
# TODO: why not just dp_world_size ?
if pgm.process_group_manager.cp_dp_world_size > 1: if pgm.process_group_manager.cp_dp_world_size > 1:
model = DataParallelBucket(model) model = DataParallelBucket(model)