From 243f088170ce6d072af838e24df3a01d306b71c2 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Nov 2024 15:36:01 +0000 Subject: [PATCH] separate dataloader from utils to data.py --- picotron/data.py | 123 +++++++++++++++++++++++++++++++++++++++++++++ picotron/utils.py | 124 +--------------------------------------------- train.py | 3 +- 3 files changed, 127 insertions(+), 123 deletions(-) create mode 100644 picotron/data.py diff --git a/picotron/data.py b/picotron/data.py new file mode 100644 index 0000000..c6d4392 --- /dev/null +++ b/picotron/data.py @@ -0,0 +1,123 @@ +import torch +from torch.utils.data import DataLoader, DistributedSampler +import numpy as np +from functools import partial +from datasets import Features, Sequence, Value, load_dataset +from transformers import AutoTokenizer + +import picotron.process_group_manager as pgm + +class MicroBatchDataLoader(DataLoader): + def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None): + + self.micro_batch_size = micro_batch_size + self.seq_length = seq_length + self.grad_acc_steps = grad_acc_steps + self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size + self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size + + self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + self.dataset = load_dataset(dataset_name, split=split) + if num_samples: + self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset)))) + + # Tokenize and chunk the dataset + self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc) + + self.sampler = DistributedSampler( + self.tokenized_dataset, + num_replicas=pgm.process_group_manager.dp_world_size, + rank=pgm.process_group_manager.dp_rank, + shuffle=False + ) + + super().__init__( + self.tokenized_dataset, + batch_size=micro_batch_size, + collate_fn=self.collate_batch, + pin_memory=True, + num_workers=num_workers, + sampler=self.sampler, + shuffle=False + ) + + @staticmethod + def tokenizer_group_text(examples, tokenizer, sequence_length): + """Tokenize a list of texts and group them in chunks of sequence_length + 1""" + tokenized_text_batch = tokenizer.batch_encode_plus( + examples, + return_attention_mask=False, + return_token_type_ids=False, + return_tensors='np' + ) + concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])} + total_length = len(concatenated_tokens['input_ids']) + if total_length >= sequence_length + 1: + total_length = ((total_length - 1) // sequence_length) * sequence_length + 1 + result = { + 'input_ids': [ + concatenated_tokens['input_ids'][i : i + sequence_length + 1] + for i in range(0, total_length - sequence_length, sequence_length) + ] + } + return result + + def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc): + """Tokenize the dataset and group texts in chunks of sequence_length + 1""" + # Create a partial function with fixed arguments + tokenizer_func = partial( + self.tokenizer_group_text, + tokenizer=self.tokenizer, + sequence_length=sequence_length + ) + + tokenized_dataset = dataset.map( + tokenizer_func, + input_columns=text_column_name, + remove_columns=dataset.column_names, + features=Features({ + "input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1) + }), + batched=True, + 
num_proc=num_proc, + load_from_cache_file=True, + desc=f"Grouping texts in chunks of {sequence_length+1}", + ) + + return tokenized_dataset + + def collate_batch(self, batch): + batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch]) + batch_size = batch_input_ids.size(0) + start_idx = pgm.process_group_manager.cp_rank * self.seq_length_per_gpu + end_idx = start_idx + self.seq_length_per_gpu + input_ids = batch_input_ids[:, start_idx:end_idx].contiguous() + target_ids = batch_input_ids[:, start_idx+1:end_idx+1].contiguous() + position_ids = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous() + local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool)) + attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous() + + return { + "input_ids": input_ids, + "target_ids": target_ids, + "position_ids": position_ids, + "attn_mask": attn_mask, + "hidden_states": None + } + + def __iter__(self): + if self._iterator is None: + self._iterator = super().__iter__() + return self + + def __next__(self): + if self._iterator is None: + self._iterator = super().__iter__() + try: + batch = next(self._iterator) + except StopIteration: + self._iterator = None + raise StopIteration + return batch \ No newline at end of file diff --git a/picotron/utils.py b/picotron/utils.py index dc63ea0..5dd100f 100644 --- a/picotron/utils.py +++ b/picotron/utils.py @@ -1,15 +1,10 @@ +import os import torch import random -import os import numpy as np import builtins import fcntl import picotron.process_group_manager as pgm -import torch -from torch.utils.data import DataLoader, DistributedSampler -from functools import partial -from datasets import Features, Sequence, Value, load_dataset -from transformers import AutoTokenizer def print(*args, is_print_rank=True, **kwargs): """ solves multi-process interleaved print problem """ @@ -72,119 +67,4 @@ def load_checkpoint(model, optimizer, out_dir): raw_model.load_state_dict(checkpoint['model']) # Load optimizer state optimizer.load_state_dict(checkpoint['optimizer']) - return checkpoint['trained_steps'], checkpoint['trained_tokens'] - -class MicroBatchDataLoader(DataLoader): - def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None): - - self.micro_batch_size = micro_batch_size - self.seq_length = seq_length - self.grad_acc_steps = grad_acc_steps - self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size - self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size - - self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size - - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - self.dataset = load_dataset(dataset_name, split=split) - if num_samples: - self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset)))) - - # Tokenize and chunk the dataset - self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc) - - self.sampler = DistributedSampler( - self.tokenized_dataset, - num_replicas=pgm.process_group_manager.dp_world_size, - rank=pgm.process_group_manager.dp_rank, - shuffle=False - ) - - super().__init__( - self.tokenized_dataset, - batch_size=micro_batch_size, - collate_fn=self.collate_batch, - pin_memory=True, - num_workers=num_workers, - sampler=self.sampler, - shuffle=False - ) - - 
@staticmethod - def tokenizer_group_text(examples, tokenizer, sequence_length): - """Tokenize a list of texts and group them in chunks of sequence_length + 1""" - tokenized_text_batch = tokenizer.batch_encode_plus( - examples, - return_attention_mask=False, - return_token_type_ids=False, - return_tensors='np' - ) - concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])} - total_length = len(concatenated_tokens['input_ids']) - if total_length >= sequence_length + 1: - total_length = ((total_length - 1) // sequence_length) * sequence_length + 1 - result = { - 'input_ids': [ - concatenated_tokens['input_ids'][i : i + sequence_length + 1] - for i in range(0, total_length - sequence_length, sequence_length) - ] - } - return result - - def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc): - """Tokenize the dataset and group texts in chunks of sequence_length + 1""" - # Create a partial function with fixed arguments - tokenizer_func = partial( - self.tokenizer_group_text, - tokenizer=self.tokenizer, - sequence_length=sequence_length - ) - - tokenized_dataset = dataset.map( - tokenizer_func, - input_columns=text_column_name, - remove_columns=dataset.column_names, - features=Features({ - "input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1) - }), - batched=True, - num_proc=num_proc, - load_from_cache_file=True, - desc=f"Grouping texts in chunks of {sequence_length+1}", - ) - - return tokenized_dataset - - def collate_batch(self, batch): - batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch]) - batch_size = batch_input_ids.size(0) - start_idx = pgm.process_group_manager.cp_rank * self.seq_length_per_gpu - end_idx = start_idx + self.seq_length_per_gpu - input_ids = batch_input_ids[:, start_idx:end_idx].contiguous() - target_ids = batch_input_ids[:, start_idx+1:end_idx+1].contiguous() - position_ids = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous() - local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool)) - attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous() - - return { - "input_ids": input_ids, - "target_ids": target_ids, - "position_ids": position_ids, - "attn_mask": attn_mask, - "hidden_states": None - } - - def __iter__(self): - if self._iterator is None: - self._iterator = super().__iter__() - return self - - def __next__(self): - if self._iterator is None: - self._iterator = super().__iter__() - try: - batch = next(self._iterator) - except StopIteration: - self._iterator = None - raise StopIteration - return batch \ No newline at end of file + return checkpoint['trained_steps'], checkpoint['trained_tokens'] \ No newline at end of file diff --git a/train.py b/train.py index a7898e3..96cd14d 100644 --- a/train.py +++ b/train.py @@ -22,7 +22,8 @@ from transformers import AutoConfig import numpy as np from picotron.parallel.tensor_parallel.tensor_parallel import TensorParallel import picotron.process_group_manager as pgm -from picotron.parallel.utils import MicroBatchDataLoader, set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint +from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint +from picotron.data import MicroBatchDataLoader from picotron.process_group_manager import setup_process_group_manager from picotron.parallel.pipeline_parallel import train_step_pipeline_1f1b, 
train_step_pipeline_afab, PipelineParallel from picotron.parallel.data_parallel.data_parallel_bucket import DataParallel
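
A minimal usage sketch of the relocated class, assuming the new import path introduced by this patch. The dataset and tokenizer names below are illustrative placeholders, and the snippet assumes torch.distributed plus setup_process_group_manager() have already been initialized the same way train.py does, since MicroBatchDataLoader reads dp/cp world sizes and ranks from pgm.process_group_manager.

from picotron.data import MicroBatchDataLoader

# Assumes dist.init_process_group(...) and setup_process_group_manager(...)
# have already run, as in train.py; all argument values are illustrative.
dataloader = MicroBatchDataLoader(
    micro_batch_size=4,
    seq_length=1024,
    dataset_name="roneneldan/TinyStories",       # placeholder HF dataset exposing a "text" column
    tokenizer_name="HuggingFaceTB/SmolLM-135M",  # placeholder HF tokenizer
    num_workers=2,
    num_proc=4,
    grad_acc_steps=8,
    split="train",
)

batch = next(dataloader)            # the loader is its own iterator (__iter__/__next__)
input_ids = batch["input_ids"]      # shape: [micro_batch_size, seq_length // cp_world_size]
target_ids = batch["target_ids"]    # the same context-parallel slice shifted right by one token

Each rank only sees its context-parallel slice of the sequence: collate_batch cuts [cp_rank * seq_length_per_gpu, (cp_rank + 1) * seq_length_per_gpu) out of every seq_length + 1 token chunk produced by tokenizer_group_text, and builds position_ids and a causal attn_mask for that local window.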