accelerate dataset mapping
parent 1ca7365506 · commit 81726dfffe

train.py · 3 changed lines (+2, −1)
@@ -18,6 +18,7 @@ from parallel.data_parallel import DataParallel
 from parallel.context_parallel import ContextParallel
 from model import Llama
 import wandb
+import multiprocessing
 
 class MicroBatchDataLoader(DataLoader):
     def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
@@ -32,7 +33,7 @@ class MicroBatchDataLoader(DataLoader):
         self.dataset = load_dataset(dataset_name, split=split)
         if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
         dist.barrier()
-        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
+        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names, num_proc=multiprocessing.cpu_count()).with_format("torch", columns=["input_ids"])
 
         self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)
 
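For reference, below is a minimal standalone sketch of the technique this commit applies: the only functional change is the num_proc argument to datasets' .map(), which shards the tokenization pass across worker processes instead of running it in a single one. The dataset and tokenizer names here are illustrative placeholders, not the ones train.py is configured with.

# Minimal sketch: parallel tokenization with Hugging Face datasets' num_proc.
# "gpt2" and "roneneldan/TinyStories" are stand-in choices for illustration.
import multiprocessing

from datasets import load_dataset
from transformers import AutoTokenizer

seq_length = 128
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

def tokenize(examples):
    # Pad/truncate to seq_length + 1 so inputs and shifted targets both fit.
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=seq_length + 1,
    )

dataset = dataset.map(
    tokenize,
    batched=True,                          # tokenize whole batches per call
    remove_columns=dataset.column_names,   # drop the raw "text" column
    num_proc=multiprocessing.cpu_count(),  # fan the map out across CPU cores
).with_format("torch", columns=["input_ids"])

One thing to watch: cpu_count() workers are spawned per calling process, and in train.py every data-parallel rank constructs its own MicroBatchDataLoader, so on a multi-rank host this can oversubscribe the CPUs; capping num_proc at a fraction of the cores per rank may be the safer default.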