small change for dataloader arguments
This commit is contained in:
parent
55efb321f9
commit
fc3b50b033
@ -8,7 +8,7 @@ from transformers import AutoTokenizer
|
||||
import picotron.process_group_manager as pgm
|
||||
|
||||
class MicroBatchDataLoader(DataLoader):
|
||||
def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
|
||||
def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, subset_name=None, split="train", num_samples=None, pin_memory=True):
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.seq_length = seq_length
|
||||
self.grad_acc_steps = grad_acc_steps
|
||||
@ -18,7 +18,9 @@ class MicroBatchDataLoader(DataLoader):
|
||||
self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
self.dataset = load_dataset(dataset_name, split=split)
|
||||
|
||||
self.dataset = load_dataset(dataset_name, name=subset_name, split=split)
|
||||
|
||||
if num_samples:
|
||||
self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
|
||||
|
||||
|
||||
8
train.py
8
train.py
@ -84,13 +84,15 @@ if __name__ == "__main__":
|
||||
SEQ_LEN = config["training"]["seq_length"]
|
||||
MICRO_BATCH_SIZE = config["training"]["micro_batch_size"]
|
||||
LEARNING_RATE = config["training"]["learning_rate"]
|
||||
NUM_SAMPLES = config["training"]["num_samples"]
|
||||
MAX_TOKENS = config["training"]["max_tokens"]
|
||||
SEED = config["training"]["seed"]
|
||||
TOTAL_TRAIN_STEPS = config["training"]["total_train_steps"]
|
||||
GRAD_ACC_STEPS = config["training"]["gradient_accumulation_steps"]
|
||||
MODEL_NAME = config["model"]["name"]
|
||||
DATASET_NAME = config["dataset"]["name"]
|
||||
SUBSET_NAME = config["dataset"].get("subset_name", None)
|
||||
SPLIT = config["dataset"].get("split", "train")
|
||||
NUM_SAMPLES = config["dataset"].get("num_samples", None)
|
||||
NUM_WORKERS = config["dataset"]["num_workers"]
|
||||
NUM_PROC = config["dataset"]["num_proc"]
|
||||
USE_WANDB = config["logging"]["use_wandb"]
|
||||
@ -133,7 +135,9 @@ if __name__ == "__main__":
|
||||
grad_acc_steps=GRAD_ACC_STEPS,
|
||||
num_workers=NUM_WORKERS,
|
||||
num_proc=NUM_PROC,
|
||||
num_samples=NUM_SAMPLES
|
||||
num_samples=NUM_SAMPLES,
|
||||
subset_name=SUBSET_NAME,
|
||||
split=SPLIT
|
||||
)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user