diff --git a/data_parallel.py b/data_parallel.py
new file mode 100644
index 0000000..6aad4d4
--- /dev/null
+++ b/data_parallel.py
@@ -0,0 +1,24 @@
+import torch.distributed as dist
+import torch.nn as nn
+import parallel_context as pc
+
+class DataParallel(nn.Module):
+    def __init__(self, model):
+        #TODO: Add ZeRO-1
+        #TODO: Interleave all_reduce
+        super().__init__()
+        self.model = model
+        self.dp_world_size = pc.parallel_context.dp_world_size
+        self.dp_rank = pc.parallel_context.dp_rank
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def backward(self, input_tensor, output_tensor, output_tensor_grad):
+        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)
+
+    def all_reduce_gradients(self):
+        for param in self.model.parameters():
+            if param.grad is not None:
+                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=pc.parallel_context.dp_group)
+
\ No newline at end of file
diff --git a/generate.py b/generate.py
index 424a0da..5fbc663 100644
--- a/generate.py
+++ b/generate.py
@@ -2,7 +2,7 @@ import os
 import argparse
 
 import torch, torch.distributed as dist
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
 
 from utils import set_all_seed
 import parallel_context as pc
@@ -10,18 +10,18 @@ from parallel_context import setup_parallel_context
 from pipeline_parallel import PipelineParallel
 from distributed_primtives import communicate
 
-def run_one_inference_step(model, batch, device) -> torch.Tensor:
+def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
     if pc.parallel_context.pp_world_size == 1:
         return model.forward(batch, device)
 
     batch_size = batch["input_ids"].shape[0]
     seq_len = batch["input_ids"].shape[1]
-    tensor_shapes = (batch_size, seq_len, model.config.hidden_size)
+    tensor_shapes = (batch_size, seq_len, config.hidden_size)
 
     # Preallocate memory for output logits.
     logits = None
     if pc.parallel_context.pp_is_last_stage:
-        logits = torch.empty((batch_size, seq_len, int(model.config.vocab_size)), dtype=torch.float32, device=device)
+        logits = torch.empty((batch_size, seq_len, int(config.vocab_size)), dtype=torch.float32, device=device)
 
     recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32)
 
@@ -53,18 +53,22 @@ if __name__ == "__main__":
     device = torch.device("cuda", local_rank)
     setup_parallel_context(tp_size=1, pp_size=args.pp_size, dp_size=1)
     set_all_seed(seed=42)
-    model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
-
+    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
+    config = AutoConfig.from_pretrained(model_name)
+    base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
+    model = PipelineParallel(base_model, config).to(device)
+    del base_model
+    model.eval()
 
     # Tokenize the input
     prompts = [
         "My name is",
-        # "How old are you ?",
-        # "What is your favorite color?",
+        "How old are you ?",
+        "What is your favorite color?",
     ]
 
-    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M-Instruct")
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.padding_side = "left"
     tokenizer.pad_token = tokenizer.eos_token
 
@@ -85,7 +89,7 @@ if __name__ == "__main__":
         "hidden_states": None,
     }
 
-    logits = run_one_inference_step(model, batch_prompts, device)
+    logits = run_one_inference_step(model, batch_prompts, device, config)
 
     # Sample new token
     if pc.parallel_context.pp_is_last_stage:
diff --git a/pipeline_parallel.py b/pipeline_parallel.py
index 63c3c34..fa39659 100644
--- a/pipeline_parallel.py
+++ b/pipeline_parallel.py
@@ -1,4 +1,3 @@
-from transformers import AutoConfig, AutoModelForCausalLM
 import parallel_context as pc
 from distributed_primtives import communicate, bidirectional_communicate
 import torch, torch.nn as nn, torch.nn.functional as F
@@ -11,16 +10,13 @@ def reduce_loss_across_dp_ranks(loss, device):
     return reduced_loss.item()
 
 class PipelineParallel(nn.Module):
-    def __init__(self, model_name):
+    def __init__(self, model, config):
         super().__init__()
-        self.config = AutoConfig.from_pretrained(model_name)
-        base_model = AutoModelForCausalLM.from_pretrained(model_name, config=self.config)
-        layer_distribution = self.distribute_layers(self.config.num_hidden_layers)
-        self.embed_tokens = base_model.model.embed_tokens if pc.parallel_context.pp_is_first_stage else nn.Identity()
-        self.decoder_layers = nn.ModuleDict({str(i): base_model.model.layers[i] for i in layer_distribution})
-        self.norm = base_model.model.norm if pc.parallel_context.pp_is_last_stage else nn.Identity()
-        self.lm_head = base_model.lm_head if pc.parallel_context.pp_is_last_stage else nn.Identity()
-        del base_model
+        layer_distribution = self.distribute_layers(config.num_hidden_layers)
+        self.embed_tokens = model.model.embed_tokens if pc.parallel_context.pp_is_first_stage else nn.Identity()
+        self.decoder_layers = nn.ModuleDict({str(i): model.model.layers[i] for i in layer_distribution})
+        self.norm = model.model.norm if pc.parallel_context.pp_is_last_stage else nn.Identity()
+        self.lm_head = model.lm_head if pc.parallel_context.pp_is_last_stage else nn.Identity()
 
     def distribute_layers(self, num_layers):
         layers_per_gpu = [num_layers // pc.parallel_context.pp_world_size + (1 if i < num_layers % pc.parallel_context.pp_world_size else 0) for i in range(pc.parallel_context.pp_world_size)]
@@ -67,6 +63,9 @@ def pipeline_parallel_afab(model, data_loader, tensor_shapes, device):
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         communicate(operation='send_backward', tensor=input_tensor_grad)
 
+    # Average gradients across DP ranks
+    model.all_reduce_gradients()
+
     logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
 
@@ -115,5 +114,8 @@ def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device):
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         communicate(operation='send_backward', tensor=input_tensor_grad)
 
+    # Average gradients across DP ranks
+    model.all_reduce_gradients()
+
     logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
\ No newline at end of file
diff --git a/train.py b/train.py
index ff73270..8c2dfc9 100644
--- a/train.py
+++ b/train.py
@@ -4,18 +4,20 @@ import torch, torch.distributed as dist
 from torch.optim import AdamW
 from torch.utils.data import DataLoader, DistributedSampler
 from datasets import load_dataset
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
 import argparse
 import parallel_context as pc
 from utils import set_all_seed, display_parallelism_grid
 from parallel_context import setup_parallel_context
 from pipeline_parallel import pipeline_parallel_1f1b, pipeline_parallel_afab, PipelineParallel
+from data_parallel import DataParallel
 
 class MicroBatchDataLoader(DataLoader):
-    def __init__(self, global_batch_size, micro_batch_size, data_parallel_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
-        self.global_batch_size, self.micro_batch_size, self.data_parallel_size, self.seq_length = global_batch_size, micro_batch_size, data_parallel_size, seq_length
-        self.local_batch_size = self.global_batch_size // self.data_parallel_size
+    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
+        self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
+        self.local_batch_size = self.global_batch_size // pc.parallel_context.dp_world_size
         self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
         self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
@@ -23,7 +25,13 @@ class MicroBatchDataLoader(DataLoader):
         if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
         dist.barrier()
         self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
-        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=DistributedSampler(self.dataset, num_replicas=data_parallel_size, rank=0, shuffle=False), shuffle=False)
+
+        self.sampler = DistributedSampler(self.dataset, num_replicas=pc.parallel_context.dp_world_size, rank=pc.parallel_context.dp_rank, shuffle=False)
+
+        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
+
+    def set_epoch(self, epoch):
+        self.sampler.set_epoch(epoch)
 
     def collate_batch(self, batch_data):
         batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
@@ -54,14 +62,27 @@ if __name__ == "__main__":
     display_parallelism_grid()
     set_all_seed(seed=42)
-    model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
-    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, 1, SEQ_LEN, "roneneldan/TinyStories", "HuggingFaceTB/SmolLM-360M-Instruct", num_samples=NUM_SAMPLES)
-    tensor_shapes = (SEQ_LEN, data_loader.micro_batch_size, model.config.hidden_size)
+    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
+    config = AutoConfig.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
+
+    model = PipelineParallel(model, config).to(device)
+    model = DataParallel(model).to(device)
+
+    model.train()
+
+    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, "roneneldan/TinyStories", model_name, num_samples=NUM_SAMPLES)
+    tensor_shapes = (SEQ_LEN, data_loader.micro_batch_size, config.hidden_size)
 
     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
+
     trained_tokens, step = 0, 0
     tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN
+
+    dist.barrier()
 
-    while trained_tokens < MAX_TOKENS:
+    while trained_tokens < MAX_TOKENS:
+        data_loader.set_epoch(step)
+
         optimizer.zero_grad()
         loss = pipeline_parallel_afab(model, data_loader, tensor_shapes, device)
         optimizer.step()
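
Note on the gradient averaging introduced in data_parallel.py: dist.ReduceOp.AVG is only available in relatively recent PyTorch releases and not on every backend. A minimal sketch of an equivalent fallback, assuming the same parallel_context module used above (the helper name all_reduce_gradients_sum is hypothetical and not part of this diff), is to all-reduce with SUM and then divide by the data-parallel world size:

import torch.distributed as dist
import parallel_context as pc

def all_reduce_gradients_sum(model):
    # Sum each parameter's gradient across the data-parallel group,
    # then scale by the DP world size -- numerically equivalent to
    # dist.ReduceOp.AVG, but works on builds that lack it.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pc.parallel_context.dp_group)
            param.grad /= pc.parallel_context.dp_world_size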