add fused adam

ferdinand.mom 2024-11-02 01:18:56 +00:00
parent 7996a318c1
commit 486c1763a6
3 changed files with 17 additions and 4 deletions

View File

@@ -25,7 +25,8 @@ def create_single_config(
     mbs: int,
     seq_len: int,
     exp_name: str,
-    use_wandb: bool = False
+    use_wandb: bool = False,
+    use_fused_adam: bool = False
 ):
     run_path = os.path.join(out_dir, exp_name)
@@ -45,6 +46,7 @@ def create_single_config(
     config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
     config_content["model"]["num_attention_heads"] = tmp_model_config.num_attention_heads if num_attention_heads is None else num_attention_heads
     config_content["model"]["num_key_value_heads"] = tmp_model_config.num_key_value_heads if num_key_value_heads is None else num_key_value_heads
+    config_content["model"]["use_fused_adam"] = use_fused_adam
     del tmp_model_config
     config_content['distributed']['tp_size'] = tp
@@ -88,6 +90,7 @@ if __name__ == "__main__":
     parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
     parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
     parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
+    parser.add_argument("--use_fused_adam", action="store_true", help="Use fused adam")
     args=parser.parse_args()
@@ -107,4 +110,5 @@ if __name__ == "__main__":
         seq_len=args.seq_len,
         exp_name=args.exp_name,
         use_wandb=args.use_wandb,
+        use_fused_adam=args.use_fused_adam
     )
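
For reference, a minimal standalone sketch (not the repo's full script; unrelated parameters and the config-writing logic are omitted) of how the new `--use_fused_adam` flag is threaded from argparse into the generated config:

```python
import argparse

def create_single_config(exp_name: str, use_wandb: bool = False, use_fused_adam: bool = False):
    # Only the field touched by this commit is sketched here; the real function
    # also fills the model/distributed settings and writes the config to disk.
    config_content = {"model": {}}
    config_content["model"]["use_fused_adam"] = use_fused_adam
    return config_content

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
    parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
    parser.add_argument("--use_fused_adam", action="store_true", help="Use fused adam")
    args = parser.parse_args()
    config = create_single_config(
        exp_name=args.exp_name,
        use_wandb=args.use_wandb,
        use_fused_adam=args.use_fused_adam,
    )
    print(config)
```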

View File

@@ -4,6 +4,7 @@
         "cp_size": 1,
         "pp_size": 1,
         "dp_size": 1,
+        "pp_engine": "afab",
         "master_addr": "localhost",
         "master_port": 29500,
         "backend": "nccl",
@@ -15,7 +16,8 @@
         "num_attention_heads": 16,
         "num_key_value_heads": 4,
         "dtype": "bfloat16",
-        "use_flash_attention": true
+        "use_flash_attention": true,
+        "use_fused_adam": true
     },
     "training": {
         "seed": 42,

View File

@@ -8,7 +8,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
 """
+import inspect
 import os
 import json
 import time
@@ -184,8 +184,15 @@ if __name__ == "__main__":
     print("model to device time:", time.time()-start_time, is_print_rank=is_wandb_rank)
     tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)
-    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
+    extra_args = dict()
+    if config["model"]["use_fused_adam"]:
+        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+        use_fused = fused_available and device == 'cuda'
+        extra_args = dict(fused=True) if use_fused else dict()
+    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)
     trained_tokens, step = 0, 0
     if LOAD_PATH:
         step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
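
Below is a self-contained sketch of the fused-optimizer selection added in this hunk, with a toy model and an assumed LEARNING_RATE so it runs on its own; in the repo the flag comes from config["model"]["use_fused_adam"] and device is set earlier in train.py:

```python
import inspect

import torch
from torch.optim import AdamW

LEARNING_RATE = 3e-4  # assumed value, for illustration only

# Toy model and device handling so the sketch is runnable on its own.
model = torch.nn.Linear(8, 8)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

use_fused_adam = True  # stands in for config["model"]["use_fused_adam"]

extra_args = dict()
if use_fused_adam:
    # The fused AdamW kernel only exists in newer PyTorch builds and only runs on CUDA,
    # so probe the signature instead of assuming the keyword is accepted.
    fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and device == "cuda"
    extra_args = dict(fused=True) if use_fused else dict()

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)
```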