diff --git a/create_config.py b/create_config.py
index ce4740e..f3e53cc 100644
--- a/create_config.py
+++ b/create_config.py
@@ -124,6 +124,7 @@ def create_single_config(
     grad_acc_steps: int,
     mbs: int,
     seq_len: int,
+    subset_name: Optional[str],
     exp_name: str,
     use_wandb: bool = False,
     use_cpu: bool = False,
@@ -142,6 +143,7 @@ def create_single_config(
     config_content["environment"]["HF_TOKEN"] = hf_token
     config_content["training"]["seq_length"] = seq_len
     config_content["checkpoint"]["save_dir"] = run_path
+    config_content["dataset"]["subset_name"] = subset_name
     config_content["model"]["name"] = model_name
 
 
@@ -195,6 +197,7 @@ if __name__ == "__main__":
     parser.add_argument("--grad_acc_steps", type=int, help="grad accumulation", default=1)
     parser.add_argument("--mbs", type=int, help="micro batch size", default=1)
     parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
+    parser.add_argument("--subset_name", type=str, help="Subset name", default=None)
     parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
     parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
     parser.add_argument("--use_cpu", action="store_true", help="Use CPU for training")
@@ -217,6 +220,7 @@ if __name__ == "__main__":
         grad_acc_steps=args.grad_acc_steps,
         mbs=args.mbs,
         seq_len=args.seq_len,
+        subset_name=args.subset_name,
         exp_name=args.exp_name,
         use_wandb=args.use_wandb,
         use_cpu=args.use_cpu,
diff --git a/picotron/data.py b/picotron/data.py
index 6bf477f..74f6b3c 100644
--- a/picotron/data.py
+++ b/picotron/data.py
@@ -10,7 +10,7 @@ from picotron.utils import print
 import picotron.process_group_manager as pgm
 
 class MicroBatchDataLoader(DataLoader):
-    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, split="train", num_samples=None, pin_memory=True):
+    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, subset_name=None, split="train", num_samples=None, pin_memory=True):
         self.micro_batch_size = micro_batch_size
         self.seq_length = seq_length
         self.grad_acc_steps = grad_acc_steps
diff --git a/template/base_config.json b/template/base_config.json
index 7713767..e9ecf23 100644
--- a/template/base_config.json
+++ b/template/base_config.json
@@ -29,6 +29,7 @@
     },
     "dataset": {
         "name": "roneneldan/TinyStories",
+        "subset_name": null,
         "num_workers": 0,
         "num_proc": 1
     },
diff --git a/train.py b/train.py
index 605f30c..7f61b16 100644
--- a/train.py
+++ b/train.py
@@ -111,7 +111,8 @@ if __name__ == "__main__":
         device=device,
         num_workers=config["dataset"]["num_workers"],
         num_proc=config["dataset"]["num_proc"],
-        num_samples=config["training"]["num_samples"]
+        num_samples=config["training"]["num_samples"],
+        subset_name=config["dataset"]["subset_name"],
     )
 
     dist.barrier()