diff --git a/picotron/data.py b/picotron/data.py
index fa14948..b05e764 100644
--- a/picotron/data.py
+++ b/picotron/data.py
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
 import picotron.process_group_manager as pgm
 
 class MicroBatchDataLoader(DataLoader):
-    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
+    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, subset_name=None, split="train", num_samples=None, pin_memory=True):
         self.micro_batch_size = micro_batch_size
         self.seq_length = seq_length
         self.grad_acc_steps = grad_acc_steps
@@ -18,7 +18,9 @@ class MicroBatchDataLoader(DataLoader):
         self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
 
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-        self.dataset = load_dataset(dataset_name, split=split)
+
+        self.dataset = load_dataset(dataset_name, name=subset_name, split=split)
+
         if num_samples:
             self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
 
diff --git a/train.py b/train.py
index fcf5ca2..589e521 100644
--- a/train.py
+++ b/train.py
@@ -84,13 +84,15 @@ if __name__ == "__main__":
     SEQ_LEN = config["training"]["seq_length"]
     MICRO_BATCH_SIZE = config["training"]["micro_batch_size"]
     LEARNING_RATE = config["training"]["learning_rate"]
-    NUM_SAMPLES = config["training"]["num_samples"]
     MAX_TOKENS = config["training"]["max_tokens"]
     SEED = config["training"]["seed"]
     TOTAL_TRAIN_STEPS = config["training"]["total_train_steps"]
     GRAD_ACC_STEPS = config["training"]["gradient_accumulation_steps"]
     MODEL_NAME = config["model"]["name"]
     DATASET_NAME = config["dataset"]["name"]
+    SUBSET_NAME = config["dataset"].get("subset_name", None)
+    SPLIT = config["dataset"].get("split", "train")
+    NUM_SAMPLES = config["dataset"].get("num_samples", None)
     NUM_WORKERS = config["dataset"]["num_workers"]
     NUM_PROC = config["dataset"]["num_proc"]
     USE_WANDB = config["logging"]["use_wandb"]
@@ -133,7 +135,9 @@ if __name__ == "__main__":
         grad_acc_steps=GRAD_ACC_STEPS,
         num_workers=NUM_WORKERS,
         num_proc=NUM_PROC,
-        num_samples=NUM_SAMPLES
+        num_samples=NUM_SAMPLES,
+        subset_name=SUBSET_NAME,
+        split=SPLIT
     )
 
     dist.barrier()
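
For context, here is a minimal sketch of how the new `dataset` config keys flow into the `load_dataset` call after this change. Only the key names (`subset_name`, `split`, `num_samples`) come from the diff; the dataset id and subset values below are illustrative placeholders, not part of the repo's actual configs.

```python
from datasets import load_dataset

# Hypothetical config excerpt; key names mirror the diff, values are placeholders.
config = {
    "dataset": {
        "name": "HuggingFaceFW/fineweb-edu",  # assumed example dataset id
        "subset_name": "sample-10BT",         # HF datasets "config name"; None keeps old behavior
        "split": "train",
        "num_samples": 1000,
    }
}

SUBSET_NAME = config["dataset"].get("subset_name", None)
SPLIT = config["dataset"].get("split", "train")

# Equivalent to the call made inside MicroBatchDataLoader.__init__ after this change.
# Passing name=None is the same as omitting it, so existing configs without
# `subset_name` load exactly as before.
dataset = load_dataset(config["dataset"]["name"], name=SUBSET_NAME, split=SPLIT)
```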