From 9e9ef8236e29f1b5258173457870cf37b2f5c93e Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 25 Sep 2024 13:17:05 +0000 Subject: [PATCH] refactor to decouple pp training with normal training --- data_parallel.py | 2 +- dataset.py | 31 +++++++++++++++++++ parallel_context.py | 2 +- pipeline_parallel.py | 10 ++----- train.py | 71 +++++++++++++++++++++++++++----------------- 5 files changed, 78 insertions(+), 38 deletions(-) create mode 100644 dataset.py diff --git a/data_parallel.py b/data_parallel.py index 6aad4d4..c7b40e3 100644 --- a/data_parallel.py +++ b/data_parallel.py @@ -3,7 +3,7 @@ import torch.nn as nn import parallel_context as pc class DataParallel(nn.Module): - def __init__(self, model): + def __init__(self, model, config): #TODO: Add Zero1 #TODO: Interleave all_reduce super().__init__() diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..39c3f9d --- /dev/null +++ b/dataset.py @@ -0,0 +1,31 @@ +import torch +import torch.distributed as dist +from transformers import AutoTokenizer +from torch.utils.data import DataLoader, DistributedSampler +from datasets import load_dataset + +import parallel_context as pc + +class MicroBatchDataLoader(DataLoader): + def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None): + self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length + self.local_batch_size = self.global_batch_size // pc.parallel_context.dp_world_size + self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size + self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + self.dataset = load_dataset(dataset_name, split=split) + if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset)))) + dist.barrier() + self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"]) + + self.sampler = DistributedSampler(self.dataset, num_replicas=pc.parallel_context.dp_world_size, rank=pc.parallel_context.dp_rank, shuffle=False) + + super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def collate_batch(self, batch_data): + batch_input_ids = torch.stack([item['input_ids'] for item in batch_data]) + batch_size, seq_len = batch_input_ids.shape + return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None} diff --git a/parallel_context.py b/parallel_context.py index c5c1871..a147559 100644 --- a/parallel_context.py +++ b/parallel_context.py @@ -13,7 +13,7 @@ class ParallelContext: self.dp_size = dp_size assert self.world_size == self.tp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})" - self.grid = torch.arange(self.world_size).view(self.tp_size, self.pp_size, self.dp_size,) # TP * PP * DP grid + self.grid = torch.arange(self.world_size).view(self.tp_size, self.pp_size, self.dp_size) # TP * PP * DP grid # Find the position of the current process in the grid self.tp_rank, self.pp_rank, self.dp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() diff --git a/pipeline_parallel.py b/pipeline_parallel.py index fa39659..15c21dd 100644 --- a/pipeline_parallel.py +++ b/pipeline_parallel.py @@ -38,7 +38,7 @@ class PipelineParallel(nn.Module): torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False) return input_tensor.grad if input_tensor is not None else None -def pipeline_parallel_afab(model, data_loader, tensor_shapes, device): +def train_step_pipeline_afab(model, data_loader, tensor_shapes, device): logging_loss: torch.float32 = 0.0 input_tensors, output_tensors = [], [] @@ -63,13 +63,10 @@ def pipeline_parallel_afab(model, data_loader, tensor_shapes, device): input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) communicate(operation='send_backward', tensor=input_tensor_grad) - # Average gradient across DP ranks - model.all_reduce_gradients() - logging_loss = reduce_loss_across_dp_ranks(logging_loss, device) return logging_loss -def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device): +def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device): num_warmup_microbatches = min(pc.parallel_context.pp_world_size - pc.parallel_context.pp_rank - 1, data_loader.num_local_micro_batches) num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches logging_loss, input_tensors, output_tensors = 0.0, [], [] @@ -114,8 +111,5 @@ def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device): input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) communicate(operation='send_backward', tensor=input_tensor_grad) - # Average gradient across DP ranks - model.all_reduce_gradients() - logging_loss = reduce_loss_across_dp_ranks(logging_loss, device) return logging_loss \ No newline at end of file diff --git a/train.py b/train.py index 8c2dfc9..ada6467 100644 --- a/train.py +++ b/train.py @@ -1,42 +1,41 @@ #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2 import os +import torch.nn.functional as F import torch, torch.distributed as dist from torch.optim import AdamW -from torch.utils.data import DataLoader, DistributedSampler -from datasets import load_dataset -from transformers import AutoConfig, AutoModelForCausalLM,AutoTokenizer +from transformers import AutoConfig, AutoModelForCausalLM import argparse import parallel_context as pc from utils import set_all_seed, display_parallelism_grid from parallel_context import setup_parallel_context -from pipeline_parallel import pipeline_parallel_1f1b, pipeline_parallel_afab, PipelineParallel +from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel from data_parallel import DataParallel +from dataset import MicroBatchDataLoader -class MicroBatchDataLoader(DataLoader): - def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None): - self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length - self.local_batch_size = self.global_batch_size // pc.parallel_context.dp_world_size - self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size - self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - self.dataset = load_dataset(dataset_name, split=split) - if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset)))) - dist.barrier() - self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"]) +def train_step(model, data_loader, device): + total_loss = 0.0 + + for _ in range(data_loader.num_local_micro_batches): + batch = next(iter(data_loader)) - self.sampler = DistributedSampler(self.dataset, num_replicas=pc.parallel_context.dp_world_size, rank=pc.parallel_context.dp_rank, shuffle=False) - - super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False) + input_ids = batch["input_ids"].to(device) + position_ids = batch["position_index"].to(device) + target_ids = batch["target_ids"].to(device) - def set_epoch(self, epoch): - self.sampler.set_epoch(epoch) + outputs = model(input_ids=input_ids, position_ids=position_ids) + logits = outputs.logits - def collate_batch(self, batch_data): - batch_input_ids = torch.stack([item['input_ids'] for item in batch_data]) - batch_size, seq_len = batch_input_ids.shape - return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None} + # Use your suggested cross_entropy calculation + loss = F.cross_entropy(logits.transpose(1, 2), target_ids, reduction='mean') + + loss.backward() + + total_loss += loss.item() + + avg_loss = total_loss / data_loader.num_local_micro_batches + return avg_loss if __name__ == "__main__": @@ -64,10 +63,13 @@ if __name__ == "__main__": set_all_seed(seed=42) model_name = "HuggingFaceTB/SmolLM-360M-Instruct" config = AutoConfig.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name, config=config) + model = AutoModelForCausalLM.from_pretrained(model_name, config=config).to(device) - model = PipelineParallel(model, config).to(device) - model = DataParallel(model).to(device) + if pc.parallel_context.pp_world_size > 1: + model = PipelineParallel(model, config).to(device) + + if pc.parallel_context.dp_world_size > 1: + model = DataParallel(model, config).to(device) model.train() @@ -79,12 +81,25 @@ if __name__ == "__main__": tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN dist.barrier() + + #TODO: find a way to setup reference model training + #TODO: Add activation checkpointing + #TODO: add gradient accumulation while trained_tokens < MAX_TOKENS: data_loader.set_epoch(step) optimizer.zero_grad() - loss = pipeline_parallel_afab(model, data_loader, tensor_shapes, device) + + if pc.parallel_context.pp_world_size > 1: + loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device) + else: + loss = train_step(model, data_loader, device) + + if pc.parallel_context.dp_world_size > 1: + # Average gradient across DP ranks + model.all_reduce_gradients() + optimizer.step() trained_tokens += tokens_per_step step += 1