add wandb
This commit is contained in:
parent 3c635092f9
commit fdf2df8344
create_config.py
@@ -1,6 +1,6 @@
 """
-python create_config.py --out_dir tmp --exp_name test_2_node --tp 2 --cp 2 --pp 2 --dp 2 --model_name HuggingFaceTB/SmolLM-360M-Instruct
+python create_config.py --out_dir tmp --exp_name test_2_node --tp 2 --cp 2 --pp 2 --dp 2 --model_name HuggingFaceTB/SmolLM-360M-Instruct --num_attention_heads 16 --num_key_value_heads 4 --grad_acc 1 --mbs 32 --seq_len 4096 --use_wandb
 """
 
 from copy import deepcopy
 from transformers import AutoConfig
@@ -24,6 +24,7 @@ def create_single_config(
     mbs: int,
     seq_len: int,
     exp_name: str,
+    use_wandb: bool = False
 ):
     run_path = os.path.join(out_dir, exp_name)
 
@@ -50,6 +51,9 @@ def create_single_config(
     config_content['distributed']['pp_size'] = pp
     config_content['distributed']['dp_size'] = dp
 
+    config_content['logging']['use_wandb'] = use_wandb
+    config_content['logging']['run_name'] = exp_name
+
     gbs = dp * mbs * grad_acc
     gbs_token = gbs * seq_len
     print(f"Gbs_token: {gbs_token:,}, Gbs: {gbs}, dp: {dp}, seq_len: {seq_len}, grad_acc: {grad_acc}, mbs: {mbs}")
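
Note: a quick sanity check of the printed values for the example command in the file docstring (dp=2, mbs=32, grad_acc=1, seq_len=4096); this is just arithmetic on the formula above, not part of the change:

    gbs = 2 * 32 * 1         # dp * mbs * grad_acc = 64 sequences per optimizer step
    gbs_token = 64 * 4096    # gbs * seq_len = 262,144 tokens per step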
@@ -80,7 +84,8 @@ if __name__ == "__main__":
     parser.add_argument("--mbs", type=int, help="micro batch size", default=1)
     parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
     parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
+    parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
 
     args=parser.parse_args()
 
     create_single_config(
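
Note: because --use_wandb is declared with action="store_true", it defaults to False when omitted, so the generated config leaves wandb disabled unless the flag is passed. A minimal sketch of the effect (dict keys from the hunk above; values for the example command in the docstring):

    config_content['logging']['use_wandb']   # True only when --use_wandb is passed
    config_content['logging']['run_name']    # "test_2_node" for the example command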
@@ -97,4 +102,5 @@ if __name__ == "__main__":
         mbs=args.mbs,
         seq_len=args.seq_len,
         exp_name=args.exp_name,
+        use_wandb=args.use_wandb,
     )
train.py
@@ -146,7 +146,7 @@ if __name__ == "__main__":
     if is_wandb_rank and USE_WANDB:
         wandb.init(
             project="picotron",
-            name=f"test_convergence_GBS_{tokens_per_step}_{pgm.process_group_manager}",
+            name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}",
             config={
                 "tensor_parallel_size": pgm.process_group_manager.tp_size,
                 "context_parallel_size": pgm.process_group_manager.cp_size,
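
Note: with this change the wandb run name is taken from the generated config rather than the hardcoded test_convergence_GBS prefix. A sketch of the resulting name for the example command, assuming tokens_per_step matches the Gbs_token printed by create_config.py (the trailing suffix depends on the process group manager's string form, which is not shown in this diff):

    # config['logging']['run_name'] == "test_2_node" for the example command
    name = f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}"
    # -> something like "test_2_node_262144_<process group layout>"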