diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -117,6 +117,18 @@ if __name__ == "__main__":
 
     dist.barrier()
 
+    set_all_seed(SEED)
+
+    model_config = AutoConfig.from_pretrained(MODEL_NAME)
+    model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
+    model_config.num_attention_heads = config["model"]["num_attention_heads"]
+    model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
+    model_config.max_position_embeddings = SEQ_LEN
+
+    start_time = time.time()
+    model = Llama(config=model_config)
+    print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
+
     set_all_seed(SEED)
 
     start_time = time.time()