Merge branch 'main' into add-grad-acc-pp
This commit is contained in:
commit
cecdafe515
13
train.py
13
train.py
@ -10,6 +10,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --
|
|||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import inspect
|
import inspect
|
||||||
|
import datetime
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import datetime
|
import datetime
|
||||||
@ -117,6 +118,18 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
|
set_all_seed(SEED)
|
||||||
|
|
||||||
|
model_config = AutoConfig.from_pretrained(MODEL_NAME)
|
||||||
|
model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
|
||||||
|
model_config.num_attention_heads = config["model"]["num_attention_heads"]
|
||||||
|
model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
|
||||||
|
model_config.max_position_embeddings = SEQ_LEN
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
model = Llama(config=model_config)
|
||||||
|
print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
|
||||||
|
|
||||||
set_all_seed(SEED)
|
set_all_seed(SEED)
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user