From 24ff8d05fd1731ae5a2d1319cffe335ead64f0ed Mon Sep 17 00:00:00 2001
From: zzhhjjj
Date: Wed, 16 Oct 2024 16:48:55 +0000
Subject: [PATCH] add DDP

---
 src/parallel/data_parallel/__init__.py       |   0
 src/parallel/data_parallel/bucket.py         | 173 ++++++++++++++++++
 src/parallel/data_parallel/data_parallel.py  |  54 ++++++
 .../data_parallel/data_parallel_bucket.py    | 119 ++++++++++++
 train.py                                     |   4 +
 utils.py                                     |   2 +-
 6 files changed, 351 insertions(+), 1 deletion(-)
 create mode 100644 src/parallel/data_parallel/__init__.py
 create mode 100644 src/parallel/data_parallel/bucket.py
 create mode 100644 src/parallel/data_parallel/data_parallel.py
 create mode 100644 src/parallel/data_parallel/data_parallel_bucket.py

diff --git a/src/parallel/data_parallel/__init__.py b/src/parallel/data_parallel/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/parallel/data_parallel/bucket.py b/src/parallel/data_parallel/bucket.py
new file mode 100644
index 0000000..450e488
--- /dev/null
+++ b/src/parallel/data_parallel/bucket.py
@@ -0,0 +1,173 @@
+from typing import List
+import torch
+import torch.distributed as dist
+from torch import nn
+
+class Bucket:
+    def __init__(self, params: List[torch.nn.Parameter], grad_data: torch.Tensor, process_group: torch.distributed.ProcessGroup) -> None:
+        """
+        Initializes a Bucket instance.
+
+        Args:
+            params (List[torch.nn.Parameter]): List of parameters assigned to this bucket.
+            grad_data (torch.Tensor): Tensor to store the gradients for this bucket.
+            process_group (torch.distributed.ProcessGroup): Process group used for synchronizing gradients.
+        """
+        self.params = set(params) # Set of parameters in this bucket.
+        self.params_with_grad_ready = set() # Parameters whose gradients are ready; the all-reduce is launched once every parameter in the bucket is ready.
+        self.grad_data = grad_data # Tensor that stores gradients for all parameters in this bucket.
+        self.process_group = process_group # Process group for gradient synchronization.
+        self.process_group_size = dist.get_world_size(group=self.process_group)
+        self.handle = None # Handle for the async all-reduce operation.
+
+        self.reset()
+
+    def sync_gradient(self) -> None:
+        """
+        Launch an asynchronous all-reduce operation to synchronize gradients across processes.
+        """
+        assert self.handle is None
+        self.grad_data /= self.process_group_size # Divide before launching the all-reduce so the summed result is the average gradient.
+        self.handle = dist.all_reduce(self.grad_data, group=self.process_group, async_op=True)
+
+    def reset(self) -> None:
+        """
+        Reset the bucket to its initial state. Typically called after the gradient synchronization is finished.
+        """
+        self.handle = None
+        self.params_with_grad_ready.clear() # Clear the set of parameters ready for gradient synchronization.
+        self.grad_data.zero_() # Zero the gradient tensor.
+
+    def wait(self) -> None:
+        """
+        Wait for the all-reduce operation to finish.
+        """
+        assert self.handle is not None, "You should launch an allreduce operation before waiting for it to finish"
+        self.handle.wait() # Block until the all-reduce operation finishes.
+
+    def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
+        """
+        Mark a parameter as ready for gradient synchronization. Launches synchronization when all parameters in the
+        bucket have their gradients ready.
+        """
+        assert param in self.params and param not in self.params_with_grad_ready
+        self.params_with_grad_ready.add(param)
+        # When all parameters in the bucket have their gradients ready, synchronize gradients.
+        if len(self.params_with_grad_ready) == len(self.params):
+            self.sync_gradient()
+
+class BucketManager:
+    def __init__(self, params: List[torch.nn.Parameter], process_group: torch.distributed.ProcessGroup, bucket_size: int, grad_type: torch.dtype = torch.float32) -> None:
+        """
+        Initializes the BucketManager.
+
+        Args:
+            params (List[torch.nn.Parameter]): List of model parameters.
+            process_group (torch.distributed.ProcessGroup): Process group used for gradient synchronization.
+            bucket_size (int): Maximum size of each bucket in terms of gradient elements.
+            grad_type (torch.dtype, optional): Data type of gradients, defaults to torch.float32.
+        """
+        self.params = list(params) # Convert parameter generator to a list.
+        self.buckets = [] # List of buckets.
+        self.process_group = process_group
+        self.process_group_size = dist.get_world_size(group=self.process_group)
+        self.params_to_bucket_location = {} # Map each parameter to its location in its bucket: (start index, end index, bucket index).
+        self.bucket_size = bucket_size
+        self.bucket_sizes = None # Actual sizes of each bucket.
+        self.grad_data_list = [] # List of tensors to store gradients, one tensor per bucket.
+        self.grad_type = grad_type
+        # Divide gradients into buckets based on the provided bucket size.
+        self._initialize_buckets()
+
+
+    def _initialize_buckets(self) -> None:
+        """
+        Divides model parameters into buckets for gradient synchronization based on the bucket size.
+        """
+        cur_bucket_size = 0
+        cur_bucket_idx = 0
+
+        # Assign parameters to buckets.
+        for param in self.params:
+            if not param.requires_grad:
+                continue
+
+            # If the bucket is empty, add the parameter to the bucket.
+            if cur_bucket_size == 0:
+                self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
+                cur_bucket_size = param.numel()
+                continue
+
+            # If the parameter does not fit in the current bucket, create a new bucket.
+            if cur_bucket_size + param.numel() > self.bucket_size:
+                cur_bucket_idx += 1
+                self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
+                cur_bucket_size = param.numel()
+            else:
+                self.params_to_bucket_location[param] = (cur_bucket_size, cur_bucket_size + param.numel(), cur_bucket_idx)
+                cur_bucket_size += param.numel()
+
+        # Gather information about the bucket sizes and the parameters in each bucket.
+        bucket_sizes = [0] * (cur_bucket_idx + 1)
+        buckets_to_params = [[] for _ in range(cur_bucket_idx + 1)]
+        for param, (_, end, idx) in self.params_to_bucket_location.items():
+            bucket_sizes[idx] = max(bucket_sizes[idx], end)
+            buckets_to_params[idx].append(param)
+
+        # Create tensors for storing gradients and initialize Bucket objects.
+        for i in range(len(bucket_sizes)):
+            self.grad_data_list.append(torch.zeros(bucket_sizes[i], dtype=self.grad_type, device='cuda'))
+            self.buckets.append(Bucket(buckets_to_params[i], self.grad_data_list[i], self.process_group))
+
+        # Create gradient views for each parameter.
+        for param in self.params[::-1]:
+            if not param.requires_grad:
+                continue
+            data_start_index, data_end_index, bucket_id = self.params_to_bucket_location[param]
+            # param.main_grad is used for gradient calculation
+            param.main_grad = self._get_view_from_tensor(self.grad_data_list[bucket_id], param.shape, data_start_index, data_end_index)
+
+    def _get_view_from_tensor(self, tensor: torch.Tensor, shape: torch.Size, start: int, end: int) -> torch.Tensor:
+        """
+        Create a view of the given tensor with the specified shape from start to end indices.
+        """
+        return tensor[start:end].view(shape)
+
+    def reset(self) -> None:
+        """
+        Reset all buckets by clearing the gradients and internal states.
+        """
+        for bucket in self.buckets:
+            bucket.reset()
+
+    def wait(self) -> None:
+        """
+        Wait for all buckets to complete their gradient synchronization.
+        """
+        for bucket in self.buckets:
+            bucket.wait()
+
+    def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
+        """
+        Mark a parameter's gradient as ready for synchronization.
+        """
+        bucket_idx = self.params_to_bucket_location[param][2]
+        self.buckets[bucket_idx].mark_param_as_ready(param)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/parallel/data_parallel/data_parallel.py b/src/parallel/data_parallel/data_parallel.py
new file mode 100644
index 0000000..dec0104
--- /dev/null
+++ b/src/parallel/data_parallel/data_parallel.py
@@ -0,0 +1,54 @@
+import contextlib
+import torch
+import torch.distributed as dist
+from torch import nn
+
+class DataParallel(nn.Module):
+    def __init__(self, module, process_group):
+        """
+        Initializes the DataParallel wrapper for a given module.
+
+        Args:
+            module (nn.Module): The model to be wrapped for data parallelism.
+            process_group (torch.distributed.ProcessGroup): The process group used for gradient synchronization.
+                                                            It could be a data parallel or context parallel group.
+        """
+        super().__init__()
+        self.module = module
+        self.process_group = process_group # process group for gradient synchronization. could be a data parallel or a context parallel group
+        self.dp_world_size = dist.get_world_size(group=self.process_group)
+        self.require_backward_grad_sync = True # whether to synchronize gradients during the backward pass. Set to False when using gradient accumulation
+        self.register_backward_hook(self._allreduce_grads)
+
+    def forward(self, *inputs, **kwargs):
+        return self.module(*inputs, **kwargs)
+
+    def register_backward_hook(self, hook):
+        """
+        Registers a backward hook for all parameters of the model that require gradients.
+        """
+        for p in self.module.parameters():
+            if p.requires_grad:
+                p.register_hook(hook)
+
+    def _allreduce_grads(self, grad):
+        """
+        Performs an all-reduce operation to synchronize gradients across multiple processes.
+        """
+        # No synchronization needed during gradient accumulation, except at the final accumulation step.
+        # 324K tokens/s/gpu -> 334K tokens/s/gpu
+        if self.require_backward_grad_sync:
+            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=self.process_group)
+            grad /= self.dp_world_size
+        return grad
+
+    @contextlib.contextmanager
+    def no_sync(self):
+        """
+        A context manager to temporarily disable gradient synchronization.
+        This is useful for performing multiple backward passes during gradient accumulation without synchronizing
+        gradients in between.
+        """
+        self.require_backward_grad_sync = False
+        yield
+        self.require_backward_grad_sync = True
\ No newline at end of file
diff --git a/src/parallel/data_parallel/data_parallel_bucket.py b/src/parallel/data_parallel/data_parallel_bucket.py
new file mode 100644
index 0000000..053598d
--- /dev/null
+++ b/src/parallel/data_parallel/data_parallel_bucket.py
@@ -0,0 +1,119 @@
+import contextlib
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.autograd import Variable
+from src.parallel.data_parallel.bucket import BucketManager
+
+class DataParallel(nn.Module):
+    def __init__(self, module, process_group, bucket_cap_mb=25, grad_type=torch.float32):
+        """
+        Initialize the DataParallel module.
+
+        Args:
+            module (nn.Module): The model to be parallelized.
+            process_group: The process group for gradient synchronization, which can be either
+                           a data parallel group or a context parallel group.
+            bucket_cap_mb (int, optional): The maximum size of each gradient synchronization bucket in megabytes.
+                                           Defaults to 25 MB.
+            grad_type (torch.dtype, optional): The data type of gradients, defaulting to float32.
+        """
+        super().__init__()
+        self.module = module
+        self.process_group = process_group # process group for gradient synchronization. could be a data parallel or a context parallel group
+        self.dp_world_size = dist.get_world_size(group=self.process_group)
+        self.require_backward_grad_sync = True # whether to synchronize gradients during the backward pass. Set to False when using gradient accumulation
+        grad_size = 2 if grad_type == torch.bfloat16 else 4 # bfloat16 gradient: 2 bytes, float32 gradient: 4 bytes
+        bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradient elements per bucket
+        self.bucket_manager = BucketManager(module.parameters(), self.process_group, bucket_size, grad_type)
+        self.register_backward_hook()
+        self._post_backward_callback_set = False # whether the post-backward callback that waits for gradient synchronization has been queued
+
+    def forward(self, *inputs, **kwargs):
+        return self.module(*inputs, **kwargs)
+
+    def backward(self, input_tensor, output_tensor, output_tensor_grad):
+        return self.module.backward(input_tensor, output_tensor, output_tensor_grad)
+
+    def get_flops(self, *args, **kwargs):
+        return self.module.get_flops(*args, **kwargs)
+
+    def register_backward_hook(self):
+        """
+        Registers a backward hook to manually accumulate and synchronize gradients.
+
+        This hook serves two main purposes:
+        1. PyTorch does not natively support gradient accumulation with mixed precision.
+        2. After gradient accumulation, it flags parameters as ready for synchronization.
+
+        The gradient accumulation functions are stored to prevent them from going out of scope.
+
+        References:
+        - https://github.com/NVIDIA/Megatron-LM/issues/690
+        - https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html
+        - https://arxiv.org/abs/2006.15704 (page 5)
+        """
+        self.grad_accs = []
+        for param in self.module.parameters():
+            if param.requires_grad:
+                # Expand so we get access to grad_fn.
+                param_tmp = param.expand_as(param)
+                # Get the gradient accumulator function.
+                grad_acc = param_tmp.grad_fn.next_functions[0][0]
+                grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
+                self.grad_accs.append(grad_acc)
+
+    def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
+        """
+        Creates a hook for each parameter to handle gradient accumulation and synchronization.
+        """
+        def param_hook(*unused):
+            """
+            The hook called after the gradient is ready. It performs the following:
+            1. Accumulates the gradient into the main gradient.
+            2. Adds a post-backward callback to wait for gradient synchronization completion.
+            3. Marks the parameter as ready for synchronization.
+            """
+            if param.requires_grad:
+                assert param.grad is not None
+                param.main_grad.add_(param.grad.data) # accumulate the gradients
+                param.grad = None
+
+            # skip the gradient synchronization (gradient accumulation/PP micro batches)
+            if self.require_backward_grad_sync:
+                # Add a callback to wait for gradient synchronization. Ensures the callback is added only once.
+                # The callback is executed after the backward pass, so it must be queued once per backward pass.
+                if not self._post_backward_callback_set:
+                    Variable._execution_engine.queue_callback(self._post_backward)
+                    self._post_backward_callback_set = True
+
+                # mark the parameter as ready for gradient synchronization.
+                bucket_manager.mark_param_as_ready(param)
+        return param_hook
+
+    @contextlib.contextmanager
+    def no_sync(self):
+        """A context manager to disable gradient synchronization."""
+        self.require_backward_grad_sync = False
+        yield
+        self.require_backward_grad_sync = True
+
+    def _post_backward(self):
+        """
+        A post-backward callback that waits for gradient synchronization to finish, then copies
+        the synchronized gradients back to the parameters' grad attribute.
+
+        This method is called after the backward pass and before the optimizer step.
+        """
+        self.bucket_manager.wait()
+        self._post_backward_callback_set = False
+        # copy to params.grad so we can use the optimizer to update the parameters
+        for p in self.module.parameters():
+            if p.requires_grad:
+                p.grad = p.main_grad.to(p.dtype) # In PyTorch, you cannot assign a gradient with one data type to a tensor of another data type.
+
+    def reset(self):
+        """
+        Reset the bucket manager and zero out the gradient buffers of the model.
+        """
+        self.bucket_manager.reset()
\ No newline at end of file
diff --git a/train.py b/train.py
index 3beb1a9..9df4252 100644
--- a/train.py
+++ b/train.py
@@ -265,6 +265,10 @@ if __name__ == "__main__":
         trained_tokens += tokens_per_step
         step += 1
 
+        # The DDP implementation requires resetting the gradient buffers after each step.
+        if hasattr(model, 'reset'):
+            model.reset()
+
         if pgm.process_group_manager.global_rank == 0:
             if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
                 handle.wait()
diff --git a/utils.py b/utils.py
index b2dcb90..f172bf9 100644
--- a/utils.py
+++ b/utils.py
@@ -3,7 +3,7 @@ import random
 import numpy as np
 import builtins
 import fcntl
-import distributed.process_group_manager as pgm
+import src.distributed.process_group_manager as pgm
 
 def print(*args, **kwargs):
     """ solves multi-process interleaved print problem """
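
Usage sketch (not part of the patch): the snippet below shows one way the bucketed DataParallel wrapper and its no_sync/reset API could be driven from a training loop with gradient accumulation, mirroring the model.reset() call added to train.py. The process-group setup, the nn.Linear model, the random batch, the loss, and the accumulation factor are illustrative placeholders; only DataParallel(module, process_group, bucket_cap_mb), no_sync(), and reset() come from this patch.

    import contextlib
    import torch
    import torch.distributed as dist
    from torch import nn
    from src.parallel.data_parallel.data_parallel_bucket import DataParallel

    # Assumed launch: one process per GPU via torchrun (illustrative setup).
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = nn.Linear(1024, 1024).cuda()                 # placeholder model
    model = DataParallel(model, dist.group.WORLD, bucket_cap_mb=25)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    grad_acc_steps = 4                                   # placeholder accumulation factor

    for step in range(100):
        x = torch.randn(8, 1024, device="cuda")          # placeholder batch
        is_sync_step = (step + 1) % grad_acc_steps == 0
        # Disable gradient synchronization for every micro-batch except the last one of the cycle.
        ctx = contextlib.nullcontext() if is_sync_step else model.no_sync()
        with ctx:
            loss = model(x).float().pow(2).mean()        # placeholder loss
            loss.backward()                              # hooks accumulate into param.main_grad
        if is_sync_step:
            optimizer.step()                             # uses param.grad copied back in _post_backward
            optimizer.zero_grad()
            model.reset()                                # clear the gradient buckets before the next cycle

    dist.destroy_process_group()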