diff --git a/create_config.py b/create_config.py
index f02dd18..fb080e4 100644
--- a/create_config.py
+++ b/create_config.py
@@ -25,7 +25,8 @@ def create_single_config(
     mbs: int,
     seq_len: int,
     exp_name: str,
-    use_wandb: bool = False
+    use_wandb: bool = False,
+    use_fused_adam: bool = False
 ):
     run_path = os.path.join(out_dir, exp_name)
 
@@ -45,6 +46,7 @@ def create_single_config(
     config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
     config_content["model"]["num_attention_heads"] = tmp_model_config.num_attention_heads if num_attention_heads is None else num_attention_heads
     config_content["model"]["num_key_value_heads"] = tmp_model_config.num_key_value_heads if num_key_value_heads is None else num_key_value_heads
+    config_content["model"]["use_fused_adam"] = use_fused_adam
     del tmp_model_config
 
     config_content['distributed']['tp_size'] = tp
@@ -88,6 +90,7 @@ if __name__ == "__main__":
     parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
     parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
     parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
+    parser.add_argument("--use_fused_adam", action="store_true", help="Use fused adam")
 
     args=parser.parse_args()
 
@@ -107,4 +110,5 @@ if __name__ == "__main__":
         seq_len=args.seq_len,
         exp_name=args.exp_name,
         use_wandb=args.use_wandb,
+        use_fused_adam=args.use_fused_adam
     )
diff --git a/template/base_config.json b/template/base_config.json
index 5d86fe3..625595a 100644
--- a/template/base_config.json
+++ b/template/base_config.json
@@ -4,6 +4,7 @@
         "cp_size": 1,
         "pp_size": 1,
         "dp_size": 1,
+        "pp_engine": "afab",
         "master_addr": "localhost",
         "master_port": 29500,
         "backend": "nccl",
@@ -15,7 +16,8 @@
         "num_attention_heads": 16,
         "num_key_value_heads": 4,
         "dtype": "bfloat16",
-        "use_flash_attention": true
+        "use_flash_attention": true,
+        "use_fused_adam": true
     },
     "training": {
         "seed": 42,
diff --git a/train.py b/train.py
index 3366bec..b8295a7 100644
--- a/train.py
+++ b/train.py
@@ -8,7 +8,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
 """
-
+import inspect
 import os
 import json
 import time
@@ -184,8 +184,15 @@ if __name__ == "__main__":
     print("model to device time:", time.time()-start_time, is_print_rank=is_wandb_rank)
 
     tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)
 
-    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
+    extra_args = dict()
+    if config["model"]["use_fused_adam"]:
+        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+        use_fused = fused_available and device == 'cuda'
+        extra_args = dict(fused=True) if use_fused else dict()
+
+    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)
+
     trained_tokens, step = 0, 0
     if LOAD_PATH:
         step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
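Note (not part of the patch): the `train.py` hunk above enables the fused AdamW kernel only when the installed `torch.optim.AdamW` actually exposes a `fused` parameter and the target device is CUDA. A minimal standalone sketch of that same capability check follows; the `build_optimizer` helper and the toy model are illustrative placeholders, not names from this diff.

```python
import inspect

import torch
from torch.optim import AdamW


def build_optimizer(params, lr, use_fused_adam, device):
    """Return AdamW, enabling the fused kernel only when it is supported."""
    extra_args = dict()
    if use_fused_adam:
        # Older PyTorch releases do not accept `fused`; guard on the signature.
        fused_available = 'fused' in inspect.signature(AdamW).parameters
        if fused_available and device == 'cuda':
            extra_args = dict(fused=True)
    return AdamW(params, lr=lr, **extra_args)


if __name__ == "__main__":
    model = torch.nn.Linear(8, 8)  # placeholder model for the sketch
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    optimizer = build_optimizer(model.parameters(), lr=3e-4, use_fused_adam=True, device=device)
    print(optimizer)
```

On PyTorch builds without the fused kernel, or when running on CPU, the sketch silently falls back to the default AdamW implementation, which mirrors the fallback behaviour of the patched `train.py`.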