Merge branch 'main' into loading_big_model

This commit is contained in:
ferdinand.mom 2024-12-18 17:02:48 +00:00
commit 7daefd31ee
4 changed files with 8 additions and 2 deletions

View File

@@ -124,6 +124,7 @@ def create_single_config(
grad_acc_steps: int,
mbs: int,
seq_len: int,
subset_name: Optional[str],
exp_name: str,
use_wandb: bool = False,
use_cpu: bool = False,
@@ -142,6 +143,7 @@ def create_single_config(
config_content["environment"]["HF_TOKEN"] = hf_token
config_content["training"]["seq_length"] = seq_len
config_content["checkpoint"]["save_dir"] = run_path
config_content["dataset"]["subset_name"] = subset_name
config_content["model"]["name"] = model_name
@@ -195,6 +197,7 @@ if __name__ == "__main__":
parser.add_argument("--grad_acc_steps", type=int, help="grad accumulation", default=1)
parser.add_argument("--mbs", type=int, help="micro batch size", default=1)
parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
parser.add_argument("--subset_name", type=str, help="Subset name", default=None)
parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
parser.add_argument("--use_cpu", action="store_true", help="Use CPU for training")
@@ -217,6 +220,7 @@ if __name__ == "__main__":
grad_acc_steps=args.grad_acc_steps,
mbs=args.mbs,
seq_len=args.seq_len,
subset_name=args.subset_name,
exp_name=args.exp_name,
use_wandb=args.use_wandb,
use_cpu=args.use_cpu,

View File

@@ -10,7 +10,7 @@ from picotron.utils import print
import picotron.process_group_manager as pgm
class MicroBatchDataLoader(DataLoader):
def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, split="train", num_samples=None, pin_memory=True):
def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, subset_name=None, split="train", num_samples=None, pin_memory=True):
self.micro_batch_size = micro_batch_size
self.seq_length = seq_length
self.grad_acc_steps = grad_acc_steps

View File

@@ -29,6 +29,7 @@
},
"dataset": {
"name": "roneneldan/TinyStories",
"subset_name": null,
"num_workers": 0,
"num_proc": 1
},

View File

@@ -111,7 +111,8 @@ if __name__ == "__main__":
device=device,
num_workers=config["dataset"]["num_workers"],
num_proc=config["dataset"]["num_proc"],
num_samples=config["training"]["num_samples"]
num_samples=config["training"]["num_samples"],
subset_name=config["dataset"]["subset_name"],
)
dist.barrier()