zzhhjjj 2024-10-16 16:48:55 +00:00
parent 5139a32211
commit 24ff8d05fd
6 changed files with 351 additions and 1 deletions


@ -0,0 +1,173 @@
from typing import List
import torch
import torch.distributed as dist
from torch import nn
class Bucket:
def __init__(self, params: List[torch.nn.Parameter], grad_data: torch.Tensor, process_group: torch.distributed.ProcessGroup) -> None:
"""
Initializes a Bucket instance.
Args:
params (List[torch.nn.Parameter]): List of parameters assigned to this bucket.
grad_data (torch.Tensor): Tensor to store the gradients for this bucket.
process_group (torch.distributed.ProcessGroup): Process group used for synchronizing gradients.
"""
self.params = set(params) # Set of parameters in this bucket.
        self.params_with_grad_ready = set()  # Parameters whose gradients are ready; the all-reduce is launched once every parameter in the bucket is ready.
self.grad_data = grad_data # Tensor that stores gradients for all parameters in this bucket.
self.process_group = process_group # Process group for gradient synchronization.
self.process_group_size = dist.get_world_size(group=self.process_group)
self.handle = None # Handle for the async allreduce operation.
self.reset()
def sync_gradient(self) -> None:
"""
Launch an asynchronous all-reduce operation to synchronize gradients across processes.
"""
        assert self.handle is None
        # Pre-scale by the world size so the summed result of the all-reduce is the average across ranks.
        # The scaling must happen before the async op is launched: the buffer must not be modified again until handle.wait() returns.
        self.grad_data /= self.process_group_size
        self.handle = dist.all_reduce(self.grad_data, group=self.process_group, async_op=True)
def reset(self) -> None:
"""
Reset the bucket to its initial state. Typically called after the gradient synchronization is finished.
"""
self.handle = None
self.params_with_grad_ready.clear() # Clear the set of parameters ready for gradient synchronization.
self.grad_data.zero_() # Zero the gradient tensor.
def wait(self) -> None:
"""
        Wait for the all-reduce operation to finish.
"""
assert self.handle is not None, "You should launch an allreduce operation before waiting for it to finish"
self.handle.wait() # Block until the all-reduce operation finishes.
def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
"""
Mark a parameter as ready for gradient synchronization. Launches synchronization when all parameters in the
bucket have their gradients ready.
"""
assert param in self.params and param not in self.params_with_grad_ready
self.params_with_grad_ready.add(param)
# When all parameters in the bucket have their gradients ready, synchronize gradients
if len(self.params_with_grad_ready) == len(self.params):
self.sync_gradient()
class BucketManager:
def __init__(self, params: List[torch.nn.Parameter], process_group: torch.distributed.ProcessGroup, bucket_size: int, grad_type: torch.dtype = torch.float32) -> None:
"""
Initializes the BucketManager.
Args:
params (List[torch.nn.Parameter]): List of model parameters.
process_group (torch.distributed.ProcessGroup): Process group used for gradient synchronization.
bucket_size (int): Maximum size of each bucket in terms of gradient elements.
grad_type (torch.dtype, optional): Data type of gradients, defaults to torch.float32.
"""
self.params = list(params) # Convert parameter generator to a list.
self.buckets = [] # List of buckets.
self.process_group = process_group
self.process_group_size = dist.get_world_size(group=self.process_group)
self.params_to_bucket_location = {} # Map each parameter to its corresponding bucket/place (start, end, bucket_idx).
self.bucket_size = bucket_size
self.bucket_sizes = None # Actual sizes of each bucket.
self.grad_data_list = [] # List of tensors to store gradients, one tensor per bucket.
self.grad_type = grad_type
# Divide gradients into buckets based on the provided bucket size.
self._initialize_buckets()
def _initialize_buckets(self) -> None:
"""
Divides model parameters into buckets for gradient synchronization based on the bucket size.
"""
cur_bucket_size = 0
cur_bucket_idx = 0
# Assign parameters to buckets.
for param in self.params:
if not param.requires_grad:
continue
# If the bucket is empty, add the parameter to the bucket.
if cur_bucket_size == 0:
self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
cur_bucket_size = param.numel()
continue
# If the parameter cannot fit in the current bucket, create a new bucket
if cur_bucket_size + param.numel() > self.bucket_size:
cur_bucket_idx += 1
self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
cur_bucket_size = param.numel()
else:
self.params_to_bucket_location[param] = (cur_bucket_size, cur_bucket_size + param.numel(), cur_bucket_idx)
cur_bucket_size += param.numel()
# Gather information about the bucket sizes and the parameters in each bucket
bucket_sizes = [0] * (cur_bucket_idx + 1)
buckets_to_params = [[] for _ in range(cur_bucket_idx + 1)]
for param, (_, end, idx) in self.params_to_bucket_location.items():
bucket_sizes[idx] = max(bucket_sizes[idx], end)
buckets_to_params[idx].append(param)
        # Record the actual bucket sizes, then create the gradient tensors and Bucket objects.
        self.bucket_sizes = bucket_sizes
        for i in range(len(bucket_sizes)):
self.grad_data_list.append(torch.zeros(bucket_sizes[i], dtype=self.grad_type, device='cuda'))
self.buckets.append(Bucket(buckets_to_params[i], self.grad_data_list[i], self.process_group))
# Create gradient views for each parameter.
for param in self.params[::-1]:
if not param.requires_grad:
continue
data_start_index, data_end_index, bucket_id = self.params_to_bucket_location[param]
            # param.main_grad is a view into the bucket's flat buffer; gradients are accumulated into it during the backward pass.
param.main_grad = self._get_view_from_tensor(self.grad_data_list[bucket_id], param.shape, data_start_index, data_end_index)
def _get_view_from_tensor(self, tensor: torch.Tensor, shape: torch.Size, start: int, end: int) -> torch.Tensor:
"""
Create a view of the given tensor with the specified shape from start to end indices.
"""
return tensor[start:end].view(shape)
def reset(self) -> None:
"""
Reset all buckets by clearing the gradients and internal states.
"""
for bucket in self.buckets:
bucket.reset()
def wait(self) -> None:
"""
Wait for all buckets to complete their gradient synchronization
"""
for bucket in self.buckets:
bucket.wait()
def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
"""
Mark a parameter's gradient as ready for synchronization.
"""
bucket_idx = self.params_to_bucket_location[param][2]
self.buckets[bucket_idx].mark_param_as_ready(param)
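To make the bucketing scheme concrete, here is a minimal standalone sketch (illustrative only, not part of this commit) of the flat-buffer idea that BucketManager relies on: each parameter's main_grad is a view into one contiguous tensor, so a single all-reduce on that tensor synchronizes every gradient in the bucket.

import torch

# Two toy parameters and one flat buffer playing the role of a bucket's grad_data.
params = [torch.nn.Parameter(torch.randn(4, 3)), torch.nn.Parameter(torch.randn(5))]
flat_buffer = torch.zeros(sum(p.numel() for p in params), dtype=torch.float32)

# Carve a per-parameter view out of the flat buffer, as _get_view_from_tensor does.
views, offset = [], 0
for p in params:
    views.append(flat_buffer[offset:offset + p.numel()].view(p.shape))
    offset += p.numel()

views[0].add_(1.0)                     # "accumulate" a gradient into the first view
assert flat_buffer[:12].eq(1.0).all()  # the views share storage with the buffer, so one
                                       # all-reduce on flat_buffer would sync every parameter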


@ -0,0 +1,54 @@
import contextlib
import torch
import torch.distributed as dist
from torch import nn
class DataParallel(nn.Module):
def __init__(self, module, process_group):
"""
Initializes the DataParallel wrapper for a given module.
Args:
module (nn.Module): The model to be wrapped for data parallelism.
process_group (torch.distributed.ProcessGroup): The process group used for gradient synchronization.
It could be a data parallel or context parallel group.
"""
super().__init__()
self.module = module
        self.process_group = process_group  # Process group for gradient synchronization; can be either the data parallel group or the context parallel group.
self.dp_world_size = dist.get_world_size(group=self.process_group)
self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
self.register_backward_hook(self._allreduce_grads)
def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)
def register_backward_hook(self, hook):
"""
Registers a backward hook for all parameters of the model that require gradients.
"""
for p in self.module.parameters():
            if p.requires_grad:
p.register_hook(hook)
def _allreduce_grads(self, grad):
"""
Performs an all-reduce operation to synchronize gradients across multiple processes.
"""
# No synchronization needed during gradient accumulation, except at the final accumulation step.
        # Measured throughput with this change: 324K tokens/s/GPU -> 334K tokens/s/GPU.
if self.require_backward_grad_sync:
dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=self.process_group)
grad /= self.dp_world_size
return grad
@contextlib.contextmanager
def no_sync(self):
"""
A context manager to temporarily disable gradient synchronization.
This is useful for performing multiple backward passes during gradient accumulation without synchronizing
gradients in between.
"""
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            # Re-enable synchronization even if the wrapped block raises.
            self.require_backward_grad_sync = True
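Below is a hedged usage sketch of this wrapper with gradient accumulation, matching the no_sync docstring above. It is illustrative only: it assumes a torchrun launch with NCCL available, and the model, group, and batch names are placeholders rather than part of the commit. The all-reduce hook is skipped inside no_sync() and only runs on the last micro-batch.

import contextlib
import os
import torch
import torch.distributed as dist
from torch import nn

dist.init_process_group(backend="nccl")                 # one process per GPU under torchrun
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

model = DataParallel(nn.Linear(16, 16).cuda(), process_group=dp_group)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

micro_batches = [torch.randn(8, 16, device="cuda") for _ in range(4)]
for i, batch in enumerate(micro_batches):
    last = i == len(micro_batches) - 1
    # Accumulate locally on all but the last micro-batch; synchronize only on the last one.
    with model.no_sync() if not last else contextlib.nullcontext():
        model(batch).sum().backward()
optimizer.step()
optimizer.zero_grad()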


@ -0,0 +1,119 @@
import contextlib
import torch
import torch.distributed as dist
from torch import nn
from torch.autograd import Variable
from src.parallel.data_parallel.bucket import BucketManager
class DataParallel(nn.Module):
def __init__(self, module, process_group, bucket_cap_mb=25, grad_type = torch.float32):
"""
Initialize the DataParallel module.
Args:
module (nn.Module): The model to be parallelized.
process_group: The process group for gradient synchronization, which can be either
a data parallel group or a context parallel group.
bucket_cap_mb (int, optional): The maximum size of each gradient synchronization bucket in megabytes.
Defaults to 25 MB.
grad_type (torch.dtype, optional): The data type of gradients, defaulting to float32.
"""
super().__init__()
self.module = module
        self.process_group = process_group  # Process group for gradient synchronization; can be either the data parallel group or the context parallel group.
self.dp_world_size = dist.get_world_size(group=self.process_group)
self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket
self.bucket_manager = BucketManager(module.parameters(), self.process_group, bucket_size, grad_type)
self.register_backward_hook()
        self._post_backward_callback_set = False  # whether the post-backward callback that waits for gradient synchronization has been queued for this backward pass
def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)
def backward(self, input_tensor, output_tensor, output_tensor_grad):
return self.module.backward(input_tensor, output_tensor, output_tensor_grad)
def get_flops(self, *args, **kwargs):
return self.module.get_flops(*args, **kwargs)
def register_backward_hook(self):
"""
Registers a backward hook to manually accumulate and synchronize gradients.
This hook serves two main purposes:
1. PyTorch does not natively support gradient accumulation with mixed precision.
2. After gradient accumulation, it flags parameters as ready for synchronization.
The gradient accumulation functions are stored to prevent them from going out of scope.
References:
- https://github.com/NVIDIA/Megatron-LM/issues/690
- https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html
- https://arxiv.org/abs/2006.15704 (page 5)
"""
self.grad_accs = []
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
self.grad_accs.append(grad_acc)
    def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
        """
        Creates a hook for each parameter to handle gradient accumulation and synchronization.
"""
def param_hook(*unused):
"""
The hook called after the gradient is ready. It performs the following:
1. Accumulates the gradient into the main gradient.
2. Adds a post-backward callback to wait for gradient synchronization completion.
3. Marks the parameter as ready for synchronization.
"""
if param.requires_grad:
assert param.grad is not None
param.main_grad.add_(param.grad.data) # accumulate the gradients
param.grad = None
# skip the gradient synchronization (gradient accumulation/PP micro batches)
if self.require_backward_grad_sync:
# Add a callback to wait for gradient synchronization. Ensures the callback is added only once.
# Callback is executed after the backward pass. It should be added per backward pass.
if not self._post_backward_callback_set:
Variable._execution_engine.queue_callback(self._post_backward)
self._post_backward_callback_set = True
# mark the parameter as ready for gradient synchronization.
bucket_manager.mark_param_as_ready(param)
return param_hook
@contextlib.contextmanager
def no_sync(self):
"""A context manager to disable gradient synchronization."""
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            # Re-enable synchronization even if the wrapped block raises.
            self.require_backward_grad_sync = True
def _post_backward(self):
"""
A post-backward callback that waits for gradient synchronization to finish, then copies
the synchronized gradients back to the parameters' grad attribute.
This method is called after the backward pass and before the optimizer step.
"""
self.bucket_manager.wait()
self._post_backward_callback_set = False
# copy to params.grad so we can use the optimizer to update the parameters
for p in self.module.parameters():
if p.requires_grad:
p.grad = p.main_grad.to(p.dtype) # In PyTorch, you cannot assign a gradient with one data type to a tensor of another data type.
def reset(self):
"""
Reset the bucket manager and zero out gradients in the model
"""
self.bucket_manager.reset()
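The register_backward_hook docstring above hooks each parameter's gradient-accumulator node rather than the parameter tensor itself. The tiny standalone snippet below (illustrative only, assuming a reasonably recent PyTorch) shows the expand_as trick used to reach that AccumulateGrad node.

import torch

w = torch.nn.Parameter(torch.ones(3))
# expand_as returns a non-leaf view of w, so it has a grad_fn; the first entry of
# next_functions is w's AccumulateGrad node, which runs when w's gradient is computed.
grad_acc = w.expand_as(w).grad_fn.next_functions[0][0]
grad_acc.register_hook(lambda *unused: print("gradient for w is ready"))

(w * 2.0).sum().backward()  # the hook fires when w's AccumulateGrad node executes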


@ -265,6 +265,10 @@ if __name__ == "__main__":
trained_tokens += tokens_per_step
step += 1
        # The custom DDP implementation requires resetting its gradient buffers after each step
if hasattr(model, 'reset'):
model.reset()
if pgm.process_group_manager.global_rank == 0:
if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
handle.wait()
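For context on the reset() call added above, here is a hedged sketch of one training step with the bucketed DataParallel wrapper. The names build_model, dp_group, and dataloader are placeholders, not part of the commit.

model = DataParallel(build_model().cuda(), process_group=dp_group, bucket_cap_mb=25)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for batch in dataloader:        # placeholder dataloader yielding CUDA tensors
    loss = model(batch).mean()
    loss.backward()             # hooks accumulate into main_grad and launch async bucket all-reduces
    optimizer.step()            # _post_backward has already copied the averaged grads back into .grad
    optimizer.zero_grad()
    model.reset()               # zero the flat gradient buffers before the next step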


@ -3,7 +3,7 @@ import random
import numpy as np
import builtins
import fcntl
import distributed.process_group_manager as pgm
import src.distributed.process_group_manager as pgm
def print(*args, **kwargs):
""" solves multi-process interleaved print problem """