Add subset_name and split arguments to MicroBatchDataLoader; move num_samples to the dataset config section

zzhhjjj 2024-12-18 15:55:55 +00:00
parent 55efb321f9
commit fc3b50b033
2 changed files with 10 additions and 4 deletions


@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
 import picotron.process_group_manager as pgm
 class MicroBatchDataLoader(DataLoader):
-    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
+    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, subset_name=None, split="train", num_samples=None, pin_memory=True):
         self.micro_batch_size = micro_batch_size
         self.seq_length = seq_length
         self.grad_acc_steps = grad_acc_steps
@@ -18,7 +18,9 @@ class MicroBatchDataLoader(DataLoader):
         self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-        self.dataset = load_dataset(dataset_name, split=split)
+        self.dataset = load_dataset(dataset_name, name=subset_name, split=split)
+        if num_samples:
+            self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
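
A minimal sketch of the new loading path (the dataset and subset names below are illustrative, not from this commit): in datasets.load_dataset, the name argument selects a dataset configuration ("subset"), and None falls back to the default configuration, so callers that pass no subset_name behave as before.

    from datasets import load_dataset

    # `name` picks a dataset configuration ("subset"); None uses the default config.
    dataset = load_dataset("wikitext", name="wikitext-2-raw-v1", split="train")

    # Optional truncation, mirroring the constructor logic above.
    num_samples = 1000
    dataset = dataset.select(range(min(num_samples, len(dataset))))
    print(len(dataset))  # 1000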


@@ -84,13 +84,15 @@ if __name__ == "__main__":
     SEQ_LEN = config["training"]["seq_length"]
     MICRO_BATCH_SIZE = config["training"]["micro_batch_size"]
     LEARNING_RATE = config["training"]["learning_rate"]
-    NUM_SAMPLES = config["training"]["num_samples"]
     MAX_TOKENS = config["training"]["max_tokens"]
     SEED = config["training"]["seed"]
     TOTAL_TRAIN_STEPS = config["training"]["total_train_steps"]
     GRAD_ACC_STEPS = config["training"]["gradient_accumulation_steps"]
     MODEL_NAME = config["model"]["name"]
     DATASET_NAME = config["dataset"]["name"]
+    SUBSET_NAME = config["dataset"].get("subset_name", None)
+    SPLIT = config["dataset"].get("split", "train")
+    NUM_SAMPLES = config["dataset"].get("num_samples", None)
     NUM_WORKERS = config["dataset"]["num_workers"]
     NUM_PROC = config["dataset"]["num_proc"]
     USE_WANDB = config["logging"]["use_wandb"]
@@ -133,7 +135,9 @@ if __name__ == "__main__":
         grad_acc_steps=GRAD_ACC_STEPS,
         num_workers=NUM_WORKERS,
         num_proc=NUM_PROC,
-        num_samples=NUM_SAMPLES
+        num_samples=NUM_SAMPLES,
+        subset_name=SUBSET_NAME,
+        split=SPLIT
     )
     dist.barrier()
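
For reference, a hedged sketch of the config shape the training script now reads (only the key layout comes from this diff; the concrete values are placeholders). The .get defaults make subset_name, split, and num_samples optional, but note that num_samples has moved from the training section to the dataset section.

    # Illustrative config dict; values are placeholders, key layout is from the diff.
    config = {
        "dataset": {
            "name": "wikitext",                  # required
            "subset_name": "wikitext-2-raw-v1",  # optional, defaults to None
            "split": "train",                    # optional, defaults to "train"
            "num_samples": 1000,                 # optional, defaults to None (full split)
            "num_workers": 2,
            "num_proc": 4,
        },
    }

    # Optional keys fall back to safe defaults instead of raising KeyError.
    SUBSET_NAME = config["dataset"].get("subset_name", None)
    SPLIT = config["dataset"].get("split", "train")
    NUM_SAMPLES = config["dataset"].get("num_samples", None)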