diff --git a/train.py b/train.py
index 7ee60dc..4a7fca4 100644
--- a/train.py
+++ b/train.py
@@ -18,6 +18,7 @@ from parallel.data_parallel import DataParallel
 from parallel.context_parallel import ContextParallel
 from model import Llama
 import wandb
+import multiprocessing
 
 class MicroBatchDataLoader(DataLoader):
     def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
@@ -32,7 +33,7 @@ class MicroBatchDataLoader(DataLoader):
         self.dataset = load_dataset(dataset_name, split=split)
         if num_samples:
             self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
         dist.barrier()
-        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
+        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names, num_proc=multiprocessing.cpu_count()).with_format("torch", columns=["input_ids"])
 
         self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)
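The only functional change is the `num_proc=multiprocessing.cpu_count()` argument to `datasets.Dataset.map`, which shards tokenization across one worker process per CPU core instead of running it in a single process. A minimal standalone sketch of the same pattern (the dataset name, tokenizer, and sequence length below are placeholders for illustration, not values taken from the training config):

```python
import multiprocessing

from datasets import load_dataset
from transformers import AutoTokenizer

# Placeholder values for illustration only.
DATASET_NAME = "roneneldan/TinyStories"
TOKENIZER_NAME = "gpt2"
SEQ_LENGTH = 128

def main():
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    if tokenizer.pad_token is None:
        # gpt2 ships without a pad token; reuse EOS so padding="max_length" works.
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset(DATASET_NAME, split="train[:1%]")

    def tokenize(examples):
        # Pad/truncate to SEQ_LENGTH + 1 so inputs and shifted labels both fit.
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=SEQ_LENGTH + 1,
        )

    # num_proc shards the map across CPU cores so tokenization runs in parallel.
    tokenized = dataset.map(
        tokenize,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=multiprocessing.cpu_count(),
    ).with_format("torch", columns=["input_ids"])

    print(tokenized[0]["input_ids"].shape)  # torch.Size([SEQ_LENGTH + 1])

if __name__ == "__main__":
    main()
```

The `if __name__ == "__main__":` guard matters once `num_proc > 1` spawns worker processes, and the speedup only pays off on datasets large enough to amortize the worker startup cost; for very small datasets the single-process path can be just as fast.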