add wandb

ferdinand.mom 2024-10-30 14:25:10 +00:00
parent 3c635092f9
commit fdf2df8344
2 changed files with 9 additions and 3 deletions

create_config.py

@@ -1,6 +1,6 @@
 """
-python create_config.py --out_dir tmp --exp_name test_2_node --tp 2 --cp 2 --pp 2 --dp 2 --model_name HuggingFaceTB/SmolLM-360M-Instruct
+python create_config.py --out_dir tmp --exp_name test_2_node --tp 2 --cp 2 --pp 2 --dp 2 --model_name HuggingFaceTB/SmolLM-360M-Instruct --num_attention_heads 16 --num_key_value_heads 4 --grad_acc 1 --mbs 32 --seq_len 4096 --use_wandb
 """
 from copy import deepcopy
 from transformers import AutoConfig
@@ -24,6 +24,7 @@ def create_single_config(
     mbs: int,
     seq_len: int,
     exp_name: str,
+    use_wandb: bool = False
 ):
     run_path = os.path.join(out_dir, exp_name)
@@ -50,6 +51,9 @@ def create_single_config(
     config_content['distributed']['pp_size'] = pp
     config_content['distributed']['dp_size'] = dp
+    config_content['logging']['use_wandb'] = use_wandb
+    config_content['logging']['run_name'] = exp_name
     gbs = dp * mbs * grad_acc
     gbs_token = gbs * seq_len
     print(f"Gbs_token: {gbs_token:,}, Gbs: {gbs}, dp: {dp}, seq_len: {seq_len}, grad_acc: {grad_acc}, mbs: {mbs}")
@@ -80,7 +84,8 @@ if __name__ == "__main__":
     parser.add_argument("--mbs", type=int, help="micro batch size", default=1)
     parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
     parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
+    parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
     args=parser.parse_args()
     create_single_config(
@@ -97,4 +102,5 @@ if __name__ == "__main__":
         mbs=args.mbs,
         seq_len=args.seq_len,
         exp_name=args.exp_name,
+        use_wandb=args.use_wandb,
     )
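
Taken together, the changes to create_config.py thread the new --use_wandb flag from the CLI into the config's logging section, alongside a run_name derived from exp_name. A minimal sketch of the resulting logging block, using the values from the docstring's example command (the surrounding config structure is an assumption, not shown in the diff):

    # Sketch only: what the logging section carries after create_single_config runs.
    # 'test_2_node' comes from --exp_name in the example command; other keys assumed absent.
    config_content = {'logging': {}}
    use_wandb, exp_name = True, 'test_2_node'
    config_content['logging']['use_wandb'] = use_wandb
    config_content['logging']['run_name'] = exp_name
    assert config_content['logging'] == {'use_wandb': True, 'run_name': 'test_2_node'}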

train.py

@@ -146,7 +146,7 @@ if __name__ == "__main__":
     if is_wandb_rank and USE_WANDB:
         wandb.init(
             project="picotron",
-            name=f"test_convergence_GBS_{tokens_per_step}_{pgm.process_group_manager}",
+            name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}",
             config={
                 "tensor_parallel_size": pgm.process_group_manager.tp_size,
                 "context_parallel_size": pgm.process_group_manager.cp_size,