diff --git a/bench/create_config.py b/bench/create_config.py
index 829ffc9..2522414 100644
--- a/bench/create_config.py
+++ b/bench/create_config.py
@@ -1,6 +1,6 @@
 """
-python create_config.py --out_dir tmp --exp_name test_2_node --tp 2 --cp 2 --pp 2 --dp 2 --model_name HuggingFaceTB/SmolLM-360M-Instruct
+python create_config.py --out_dir tmp --exp_name test_2_node --tp 2 --cp 2 --pp 2 --dp 2 --model_name HuggingFaceTB/SmolLM-360M-Instruct --num_attention_heads 16 --num_key_value_heads 4 --grad_acc 1 --mbs 32 --seq_len 4096 --use_wandb
 """
 from copy import deepcopy
 from transformers import AutoConfig
@@ -24,6 +24,7 @@ def create_single_config(
     mbs: int,
     seq_len: int,
     exp_name: str,
+    use_wandb: bool = False
 ):
 
     run_path = os.path.join(out_dir, exp_name)
@@ -50,6 +51,9 @@ def create_single_config(
     config_content['distributed']['pp_size'] = pp
     config_content['distributed']['dp_size'] = dp
 
+    config_content['logging']['use_wandb'] = use_wandb
+    config_content['logging']['run_name'] = exp_name
+
     gbs = dp * mbs * grad_acc
     gbs_token = gbs * seq_len
     print(f"Gbs_token: {gbs_token:,}, Gbs: {gbs}, dp: {dp}, seq_len: {seq_len}, grad_acc: {grad_acc}, mbs: {mbs}")
@@ -80,7 +84,8 @@ if __name__ == "__main__":
     parser.add_argument("--mbs", type=int, help="micro batch size", default=1)
     parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
     parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
-
+    parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
+
     args=parser.parse_args()
 
     create_single_config(
@@ -97,4 +102,5 @@ if __name__ == "__main__":
         mbs=args.mbs,
         seq_len=args.seq_len,
         exp_name=args.exp_name,
+        use_wandb=args.use_wandb,
     )
diff --git a/train.py b/train.py
index 4bb41c3..c9e99b4 100644
--- a/train.py
+++ b/train.py
@@ -146,7 +146,7 @@ if __name__ == "__main__":
     if is_wandb_rank and USE_WANDB:
         wandb.init(
             project="picotron",
-            name=f"test_convergence_GBS_{tokens_per_step}_{pgm.process_group_manager}",
+            name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}",
             config={
                 "tensor_parallel_size": pgm.process_group_manager.tp_size,
                 "context_parallel_size": pgm.process_group_manager.cp_size,
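
For reference, a minimal standalone sketch of the batch-size arithmetic printed by create_single_config, assuming the example values from the updated docstring command (dp=2, mbs=32, grad_acc=1, seq_len=4096); it is not part of the patch itself:

    # Same arithmetic as in create_single_config: gbs = dp * mbs * grad_acc, gbs_token = gbs * seq_len.
    # Values taken from the example command in the updated module docstring.
    dp, mbs, grad_acc, seq_len = 2, 32, 1, 4096

    gbs = dp * mbs * grad_acc    # global batch size in samples: 2 * 32 * 1 = 64
    gbs_token = gbs * seq_len    # global batch size in tokens: 64 * 4096 = 262,144

    print(f"Gbs_token: {gbs_token:,}, Gbs: {gbs}, dp: {dp}, seq_len: {seq_len}, grad_acc: {grad_acc}, mbs: {mbs}")
    # -> Gbs_token: 262,144, Gbs: 64, dp: 2, seq_len: 4096, grad_acc: 1, mbs: 32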