add fused adam
commit 486c1763a6 (parent 7996a318c1)
@@ -25,7 +25,8 @@ def create_single_config(
mbs: int,
seq_len: int,
exp_name: str,
use_wandb: bool = False
use_wandb: bool = False,
use_fused_adam: bool = False
):
run_path = os.path.join(out_dir, exp_name)

@@ -45,6 +46,7 @@ def create_single_config(
config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
config_content["model"]["num_attention_heads"] = tmp_model_config.num_attention_heads if num_attention_heads is None else num_attention_heads
config_content["model"]["num_key_value_heads"] = tmp_model_config.num_key_value_heads if num_key_value_heads is None else num_key_value_heads
config_content["model"]["use_fused_adam"] = use_fused_adam
del tmp_model_config

config_content['distributed']['tp_size'] = tp

@@ -88,6 +90,7 @@ if __name__ == "__main__":
parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
parser.add_argument("--use_fused_adam", action="store_true", help="Use fused adam")

args=parser.parse_args()

@@ -107,4 +110,5 @@ if __name__ == "__main__":
seq_len=args.seq_len,
exp_name=args.exp_name,
use_wandb=args.use_wandb,
use_fused_adam=args.use_fused_adam
)
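Taken together, the four hunks above thread one boolean from the command line into the generated config: a --use_fused_adam store_true flag, a use_fused_adam parameter on create_single_config, and a config_content["model"]["use_fused_adam"] entry. A minimal, self-contained sketch of that pattern follows; it is illustrative only, and the --out_dir argument, the template handling, and the config.json file name are assumptions rather than the repo's actual code.

import argparse
import json
import os

def create_single_config(out_dir: str, exp_name: str, use_fused_adam: bool = False):
    # The real script copies many more fields from a template model config;
    # only the path taken by the new flag is shown here.
    config_content = {"model": {}}
    config_content["model"]["use_fused_adam"] = use_fused_adam

    run_path = os.path.join(out_dir, exp_name)
    os.makedirs(run_path, exist_ok=True)
    with open(os.path.join(run_path, "config.json"), "w") as f:
        json.dump(config_content, f, indent=2)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, default="configs")
    parser.add_argument("--exp_name", type=str, default="dummy_exp")
    parser.add_argument("--use_fused_adam", action="store_true", help="Use fused adam")
    args = parser.parse_args()
    create_single_config(args.out_dir, args.exp_name, use_fused_adam=args.use_fused_adam)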
@@ -4,6 +4,7 @@
"cp_size": 1,
"pp_size": 1,
"dp_size": 1,
"pp_engine": "afab",
"master_addr": "localhost",
"master_port": 29500,
"backend": "nccl",

@@ -15,7 +16,8 @@
"num_attention_heads": 16,
"num_key_value_heads": 4,
"dtype": "bfloat16",
"use_flash_attention": true
"use_flash_attention": true,
"use_fused_adam": true
},
"training": {
"seed": 42,
train.py (11 changed lines)
@@ -8,7 +8,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
"""

import inspect
import os
import json
import time

@@ -184,8 +184,15 @@ if __name__ == "__main__":
print("model to device time:", time.time()-start_time, is_print_rank=is_wandb_rank)

tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

extra_args = dict()
if config["model"]["use_fused_adam"]:
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)

trained_tokens, step = 0, 0
if LOAD_PATH:
step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
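The block above probes the installed PyTorch before requesting the fused kernel: the `fused` kwarg only exists on newer AdamW releases, and the diff additionally gates it on device == 'cuda', so the extra kwargs stay empty otherwise and the call degrades to plain AdamW. Restated below as a standalone helper; this is an illustrative sketch rather than code from the repo, and it assumes device is the string "cuda" or "cpu" and that a silent fallback is acceptable.

import inspect

import torch
from torch.optim import AdamW

def build_optimizer(model: torch.nn.Module, lr: float, device: str, use_fused_adam: bool) -> AdamW:
    # Probe the AdamW signature instead of hard-coding a torch version check:
    # older PyTorch releases have no `fused` argument at all.
    extra_args = dict()
    if use_fused_adam:
        fused_available = "fused" in inspect.signature(AdamW).parameters
        if fused_available and device == "cuda":
            extra_args = dict(fused=True)
    return AdamW(model.parameters(), lr=lr, **extra_args)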