import contextlib

import torch
import torch.distributed as dist
from torch import nn
from torch.autograd import Variable

from src.parallel.data_parallel.bucket import BucketManager
import src.distributed.process_group_manager as pgm

class DataParallel(nn.Module):
    def __init__(self, module, bucket_cap_mb=25, grad_type=torch.float32):
        """
        Initialize the DataParallel module.

        Args:
            module (nn.Module): The model to be parallelized.
            bucket_cap_mb (int, optional): The maximum size of each gradient synchronization
                bucket in megabytes. Defaults to 25 MB.
            grad_type (torch.dtype, optional): The data type of gradients. Defaults to torch.float32.

        The process group used for gradient synchronization is taken from the process group
        manager (pgm); it can be either a data parallel group or a context parallel group.
        """
        super().__init__()
        self.module = module
        # TODO: refactor by using pgm to get world size etc.
        self.process_group = pgm.process_group_manager.dp_group  # process group for gradient synchronization (data parallel or context parallel group)
        self.dp_world_size = dist.get_world_size(group=self.process_group)
        self.require_backward_grad_sync = True  # whether to synchronize gradients during the backward pass; set to False when using gradient accumulation
        grad_size = 2 if grad_type == torch.bfloat16 else 4  # bytes per gradient element (float32: 4 bytes, bfloat16: 2 bytes)
        bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size  # number of gradient elements per bucket
        self.bucket_manager = BucketManager(module.parameters(), self.process_group, bucket_size, grad_type)
        self.register_backward_hook()
        self._post_backward_callback_set = False  # whether the post-backward callback that waits for gradient synchronization has been queued
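    # Worked example of the bucket sizing above (illustrative numbers only, not used by the code):
    # with the default bucket_cap_mb=25 and float32 gradients (4 bytes each),
    # bucket_size = 25 * 1024 * 1024 // 4 = 6_553_600 gradient elements per bucket;
    # with bfloat16 gradients (2 bytes each), the same cap holds twice as many elements.
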
    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        return self.module.backward(input_tensor, output_tensor, output_tensor_grad)

    def get_flops(self, *args, **kwargs):
        return self.module.get_flops(*args, **kwargs)

    def register_backward_hook(self):
        """
        Registers a backward hook on every parameter to manually accumulate and synchronize gradients.

        The hook serves two main purposes:
        1. It accumulates gradients into a separate main_grad buffer, since PyTorch does not
           natively support gradient accumulation with mixed precision.
        2. Once a gradient has been accumulated, it flags the parameter as ready for synchronization.

        The gradient accumulation functions are stored to prevent them from going out of scope.

        References:
        - https://github.com/NVIDIA/Megatron-LM/issues/690
        - https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html
        - https://arxiv.org/abs/2006.15704 (page 5)
        """
        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function (the AccumulateGrad node for this parameter).
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
                grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
                self.grad_accs.append(grad_acc)
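    # Illustrative sketch (commented out, not executed) of the AccumulateGrad hook pattern used
    # above, applied to a bare tensor; the names w and acc exist only in this comment:
    #
    #   w = torch.randn(3, requires_grad=True)
    #   acc = w.expand_as(w).grad_fn.next_functions[0][0]   # AccumulateGrad node for w
    #   acc.register_hook(lambda *_: print("w.grad is ready:", w.grad))
    #   (w * 2).sum().backward()                            # the hook fires once w.grad has been populated
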
    def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
        """
        Creates a hook for each parameter to handle gradient accumulation and synchronization.
        """
        def param_hook(*unused):
            """
            The hook called after the gradient is ready. It performs the following:
            1. Accumulates the gradient into the main gradient.
            2. Adds a post-backward callback to wait for gradient synchronization completion.
            3. Marks the parameter as ready for synchronization.
            """
            if param.requires_grad:
                assert param.grad is not None
                param.main_grad.add_(param.grad.data)  # accumulate the gradient into the main gradient buffer
                param.grad = None

                # Skip gradient synchronization (gradient accumulation / PP micro-batches).
                if self.require_backward_grad_sync:
                    # Add a callback to wait for gradient synchronization; ensure it is added only once.
                    # The callback runs after the backward pass, so it must be re-queued for every backward pass.
                    if not self._post_backward_callback_set:
                        Variable._execution_engine.queue_callback(self._post_backward)
                        self._post_backward_callback_set = True

                    # Mark the parameter as ready for gradient synchronization.
                    bucket_manager.mark_param_as_ready(param)
        return param_hook

    @contextlib.contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization."""
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            self.require_backward_grad_sync = True
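    # Illustrative usage sketch (commented out) of no_sync for gradient accumulation; model,
    # batches, and optimizer are placeholder names for this example only:
    #
    #   for i, batch in enumerate(batches):
    #       # Synchronize gradients only on the last micro-batch.
    #       ctx = model.no_sync() if i < len(batches) - 1 else contextlib.nullcontext()
    #       with ctx:
    #           loss = model(batch).mean()
    #           loss.backward()
    #   optimizer.step()
    #   model.reset()
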
    def _post_backward(self):
        """
        A post-backward callback that waits for gradient synchronization to finish, then copies
        the synchronized gradients back to the parameters' grad attribute.

        This method is called after the backward pass and before the optimizer step.
        """
        self.bucket_manager.wait()
        self._post_backward_callback_set = False
        # Copy to param.grad so the optimizer can be used to update the parameters.
        for p in self.module.parameters():
            if p.requires_grad:
                p.grad = p.main_grad.to(p.dtype)  # in PyTorch, a gradient cannot be assigned with a dtype different from its tensor
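    # Illustrative note (commented out): the cast above is needed because PyTorch refuses to assign
    # a .grad whose dtype differs from the parameter's dtype. Placeholder example:
    #
    #   p = torch.nn.Parameter(torch.randn(2, dtype=torch.bfloat16))
    #   p.grad = torch.zeros(2, dtype=torch.float32)               # raises: assigned grad has a different dtype
    #   p.grad = torch.zeros(2, dtype=torch.float32).to(p.dtype)   # works
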
    def reset(self):
        """
        Reset the bucket manager and zero out the gradients in the model.
        """
        self.bucket_manager.reset()
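
# Illustrative usage sketch (commented out): wrapping a model once torch.distributed and the
# process group manager (pgm) have been initialized elsewhere; all names below are placeholders.
#
#   model = MyModel().to(device)
#   model = DataParallel(model, bucket_cap_mb=25, grad_type=torch.bfloat16)
#   output = model(inputs)        # forward delegates to the wrapped module
#   output.mean().backward()      # hooks accumulate into main_grad and trigger bucketed synchronization
#   optimizer.step()              # gradients were copied back to .grad in _post_backward
#   model.reset()                 # clear the main_grad buffers before the next step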