Merge branch 'main' into add-grad-acc-pp
This commit is contained in:
commit
cecdafe515
13
train.py
13
train.py
@ -10,6 +10,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --
|
||||
"""
|
||||
import os
|
||||
import inspect
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
@ -117,6 +118,18 @@ if __name__ == "__main__":
|
||||
|
||||
dist.barrier()
|
||||
|
||||
set_all_seed(SEED)
|
||||
|
||||
model_config = AutoConfig.from_pretrained(MODEL_NAME)
|
||||
model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
|
||||
model_config.num_attention_heads = config["model"]["num_attention_heads"]
|
||||
model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
|
||||
model_config.max_position_embeddings = SEQ_LEN
|
||||
|
||||
start_time = time.time()
|
||||
model = Llama(config=model_config)
|
||||
print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
|
||||
|
||||
set_all_seed(SEED)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user