diff --git a/picotron/data.py b/picotron/data.py
index c6d4392..fa14948 100644
--- a/picotron/data.py
+++ b/picotron/data.py
@@ -8,8 +8,7 @@ from transformers import AutoTokenizer
 import picotron.process_group_manager as pgm

 class MicroBatchDataLoader(DataLoader):
-    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None):
-
+    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
         self.micro_batch_size = micro_batch_size
         self.seq_length = seq_length
         self.grad_acc_steps = grad_acc_steps
@@ -37,7 +36,7 @@ class MicroBatchDataLoader(DataLoader):
             self.tokenized_dataset,
             batch_size=micro_batch_size,
             collate_fn=self.collate_batch,
-            pin_memory=True,
+            pin_memory=pin_memory,
             num_workers=num_workers,
             sampler=self.sampler,
             shuffle=False
@@ -96,14 +95,11 @@ class MicroBatchDataLoader(DataLoader):
         input_ids = batch_input_ids[:, start_idx:end_idx].contiguous()
         target_ids = batch_input_ids[:, start_idx+1:end_idx+1].contiguous()
         position_ids = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
-        local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool))
-        attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous()

         return {
             "input_ids": input_ids,
             "target_ids": target_ids,
             "position_ids": position_ids,
-            "attn_mask": attn_mask,
             "hidden_states": None
         }

@@ -118,6 +114,12 @@ class MicroBatchDataLoader(DataLoader):
         try:
             batch = next(self._iterator)
         except StopIteration:
-            self._iterator = None
-            raise StopIteration
+            # Reinitialize the sampler and iterator
+            self.sampler.set_epoch(self.sampler.epoch + 1 if hasattr(self.sampler, 'epoch') else 0)
+            self._iterator = super().__iter__()
+            try:
+                batch = next(self._iterator)
+            except StopIteration:
+                self._iterator = None
+                raise StopIteration
         return batch
\ No newline at end of file
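
For context on the collate_batch change above: each context-parallel rank is expected to receive a contiguous seq_length // cp_world_size slice of every sequence, which is exactly what the new test file below checks against a non-CP reference loader. A minimal standalone sketch of that slicing arithmetic (illustrative values only, no process groups or real data; not part of the patch):

import torch

seq_length, cp_world_size = 8, 2                  # values used by test_cp_behavior below
tokens = torch.arange(seq_length).unsqueeze(0)    # stand-in for one full sequence of input_ids
split_size = seq_length // cp_world_size          # "seq_length_per_gpu" in picotron/data.py
for cp_rank in range(cp_world_size):
    start_idx = split_size * cp_rank
    end_idx = start_idx + split_size
    # rank 0 sees [[0, 1, 2, 3]], rank 1 sees [[4, 5, 6, 7]]
    print(cp_rank, tokens[:, start_idx:end_idx].tolist())
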
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
new file mode 100644
index 0000000..f865486
--- /dev/null
+++ b/tests/test_dataloader.py
@@ -0,0 +1,213 @@
+"""
+torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 test_dataloader.py
+"""
+from picotron.data import MicroBatchDataLoader
+import torch.distributed as dist
+import os
+import datetime
+from picotron.process_group_manager import setup_process_group_manager
+
+import torch
+from torch.utils.data import DataLoader, DistributedSampler
+import numpy as np
+from functools import partial
+from datasets import Features, Sequence, Value, load_dataset
+from transformers import AutoTokenizer
+
+import picotron.process_group_manager as pgm
+
+# Reference implementation: MicroBatchDataLoader without the context-parallelism split
+class DummyDataLoader(DataLoader):
+    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
+        self.micro_batch_size = micro_batch_size
+        self.seq_length = seq_length
+        self.grad_acc_steps = grad_acc_steps
+        self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size
+        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
+
+        self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
+
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        self.dataset = load_dataset(dataset_name, split=split)
+        if num_samples:
+            self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
+
+        # Tokenize and chunk the dataset
+        self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc)
+
+        self.sampler = DistributedSampler(
+            self.tokenized_dataset,
+            num_replicas=pgm.process_group_manager.dp_world_size,
+            rank=pgm.process_group_manager.dp_rank,
+            shuffle=False
+        )
+
+        super().__init__(
+            self.tokenized_dataset,
+            batch_size=micro_batch_size,
+            collate_fn=self.collate_batch,
+            pin_memory=pin_memory,
+            num_workers=num_workers,
+            sampler=self.sampler,
+            shuffle=False
+        )
+
+    @staticmethod
+    def tokenizer_group_text(examples, tokenizer, sequence_length):
+        """Tokenize a list of texts and group them in chunks of sequence_length + 1"""
+        tokenized_text_batch = tokenizer.batch_encode_plus(
+            examples,
+            return_attention_mask=False,
+            return_token_type_ids=False,
+            return_tensors='np'
+        )
+        concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])}
+        total_length = len(concatenated_tokens['input_ids'])
+        if total_length >= sequence_length + 1:
+            total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
+        result = {
+            'input_ids': [
+                concatenated_tokens['input_ids'][i : i + sequence_length + 1]
+                for i in range(0, total_length - sequence_length, sequence_length)
+            ]
+        }
+        return result
+
+    def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc):
+        """Tokenize the dataset and group texts in chunks of sequence_length + 1"""
+        # Create a partial function with fixed arguments
+        tokenizer_func = partial(
+            self.tokenizer_group_text,
+            tokenizer=self.tokenizer,
+            sequence_length=sequence_length
+        )
+
+        tokenized_dataset = dataset.map(
+            tokenizer_func,
+            input_columns=text_column_name,
+            remove_columns=dataset.column_names,
+            features=Features({
+                "input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)
+            }),
+            batched=True,
+            num_proc=num_proc,
+            load_from_cache_file=True,
+            desc=f"Grouping texts in chunks of {sequence_length+1}",
+        )
+
+        return tokenized_dataset
+
+    def collate_batch(self, batch):
+        batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
+        batch_size = batch_input_ids.size(0)
+        input_ids = batch_input_ids[:, :self.seq_length].contiguous()
+        target_ids = batch_input_ids[:, 1:self.seq_length+1].contiguous()
+        position_ids = torch.arange(0, self.seq_length, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
+
+        return {
+            "input_ids": input_ids,
+            "target_ids": target_ids,
+            "position_ids": position_ids,
+            "hidden_states": None
+        }
+
+    def __iter__(self):
+        if self._iterator is None:
+            self._iterator = super().__iter__()
+        return self
+
+    def __next__(self):
+        if self._iterator is None:
+            self._iterator = super().__iter__()
+        try:
+            batch = next(self._iterator)
+        except StopIteration:
+            # Reinitialize the sampler and iterator
+            self.sampler.set_epoch(self.sampler.epoch + 1 if hasattr(self.sampler, 'epoch') else 0)
+            self._iterator = super().__iter__()
+            try:
+                batch = next(self._iterator)
+            except StopIteration:
+                self._iterator = None
+                raise StopIteration
+        return batch
+
+# test the tokens are split correctly in context parallelism
+# TODO: test zigzag behavior
+def test_cp_behavior(TP_SIZE, CP_SIZE, PP_SIZE, DP_SIZE, SEQ_LEN=8):
+    local_rank = int(os.environ["LOCAL_RANK"])
+    global_rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    backend = "nccl"
+
+    assert SEQ_LEN % CP_SIZE == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
+    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))
+    setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE)
+
+    data_loader = MicroBatchDataLoader(
+        micro_batch_size=2,
+        seq_length=SEQ_LEN,
+        dataset_name="roneneldan/TinyStories",
+        tokenizer_name="HuggingFaceTB/SmolLM-135M",
+        grad_acc_steps=1,
+        num_workers=1,
+        num_proc=1,
+        num_samples=10,
+        pin_memory=False
+    )
+
+    ref_data_loader = DummyDataLoader(
+        micro_batch_size=2,
+        seq_length=SEQ_LEN,
+        dataset_name="roneneldan/TinyStories",
+        tokenizer_name="HuggingFaceTB/SmolLM-135M",
+        grad_acc_steps=1,
+        num_workers=1,
+        num_proc=1,
+        num_samples=10,
+        pin_memory=False
+    )
+
+    for i in range(1):
+        ref_batch = next(ref_data_loader)
+        batch = next(data_loader)
+        split_size = ref_batch["input_ids"].shape[1] // pgm.process_group_manager.cp_world_size
+        start_idx = split_size * global_rank  # global_rank == cp_rank here since tp = pp = dp = 1
+        end_idx = start_idx + split_size
+        assert torch.equal(ref_batch["input_ids"][:,start_idx:end_idx], batch["input_ids"]), "input_ids are not equal"
+
+# test the infinite loop behavior
+def test_infinite_loop():
+    local_rank = 0
+    global_rank = 0
+    world_size = 1
+    backend = "nccl"
+
+    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))
+    setup_process_group_manager(tp_size=1, cp_size=1, pp_size=1, dp_size=1)
+
+    data_loader = MicroBatchDataLoader(
+        micro_batch_size=2,
+        seq_length=256,
+        dataset_name="roneneldan/TinyStories",
+        tokenizer_name="HuggingFaceTB/SmolLM-135M",
+        grad_acc_steps=1,
+        num_workers=1,
+        num_proc=1,
+        num_samples=2,
+    )
+
+    s = set()
+    for i in range(10):
+        batch = next(data_loader)
+        # Convert the nested list to a tuple of tuples
+        batch_tuple = tuple(tuple(x) for x in batch["input_ids"].tolist())
+        if batch_tuple in s:
+            return  # a repeated batch means the loader wrapped around, as expected
+        s.add(batch_tuple)
+    assert False, "dataloader never produced a repeated batch, so it did not loop"
+
+
+if __name__ == "__main__":
+    # test_infinite_loop()
+    test_cp_behavior(TP_SIZE=1, CP_SIZE=2, PP_SIZE=1, DP_SIZE=1, SEQ_LEN=8)
\ No newline at end of file
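
For reference, a minimal single-process sketch of the wrap-around behaviour that the patched __next__ introduces (and that test_infinite_loop exercises): when the underlying iterator is exhausted, the sampler epoch is advanced where supported and a fresh iterator is created, so next() keeps yielding batches for a non-empty dataset. The CyclingDataLoader name and the toy TensorDataset are illustrative only, not part of picotron:

import torch
from torch.utils.data import DataLoader, TensorDataset

class CyclingDataLoader(DataLoader):
    def __iter__(self):
        if self._iterator is None:
            self._iterator = super().__iter__()
        return self

    def __next__(self):
        if self._iterator is None:
            self._iterator = super().__iter__()
        try:
            return next(self._iterator)
        except StopIteration:
            # Exhausted: bump the sampler epoch if the sampler supports it,
            # rebuild the iterator, and retry (mirrors the patched picotron/data.py).
            if hasattr(self.sampler, "set_epoch"):
                self.sampler.set_epoch(getattr(self.sampler, "epoch", 0) + 1)
            self._iterator = super().__iter__()
            return next(self._iterator)

if __name__ == "__main__":
    loader = CyclingDataLoader(TensorDataset(torch.arange(6)), batch_size=2, shuffle=False)
    for step in range(4):          # 3 batches per epoch, so the 4th call wraps around
        (batch,) = next(loader)
        print(step, batch.tolist())  # 0 [0, 1]  1 [2, 3]  2 [4, 5]  3 [0, 1]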