import contextlib

import torch
import torch.distributed as dist
from torch import nn
from torch.autograd import Variable

from src.parallel.data_parallel.bucket import BucketManager
import src.distributed.process_group_manager as pgm


class DataParallel(nn.Module):
    def __init__(self, module, bucket_cap_mb=25, grad_type=torch.float32):
"""
Initialize the DataParallel module .
Args :
module ( nn . Module ) : The model to be parallelized .
process_group : The process group for gradient synchronization , which can be either
a data parallel group or a context parallel group .
bucket_cap_mb ( int , optional ) : The maximum size of each gradient synchronization bucket in megabytes .
Defaults to 25 MB .
grad_type ( torch . dtype , optional ) : The data type of gradients , defaulting to float32 .
"""
        super().__init__()
        self.module = module
        # TODO: refactor by using pgm to get world size, etc.
        # Process group for gradient synchronization: either a data-parallel or a context-parallel group.
        self.process_group = pgm.process_group_manager.dp_group
        self.dp_world_size = dist.get_world_size(group=self.process_group)
        # Whether to synchronize gradients during the backward pass. Set to False when accumulating gradients.
        self.require_backward_grad_sync = True
        grad_size = 2 if grad_type == torch.bfloat16 else 4  # bytes per gradient element (2 for bfloat16, 4 for float32)
        bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size  # number of gradient elements per bucket
        self.bucket_manager = BucketManager(module.parameters(), self.process_group, bucket_size, grad_type)
        self.register_backward_hook()
        self._post_backward_callback_set = False  # whether the post-backward callback (which waits for gradient synchronization) has been queued
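        # Example: with the defaults (bucket_cap_mb=25, float32 gradients),
        # bucket_size = 25 * 1024 * 1024 // 4 = 6,553,600 gradient elements per bucket.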

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        return self.module.backward(input_tensor, output_tensor, output_tensor_grad)

    def get_flops(self, *args, **kwargs):
        return self.module.get_flops(*args, **kwargs)

    def register_backward_hook(self):
        """
        Registers a backward hook to manually accumulate and synchronize gradients.

        This hook serves two main purposes:
        1. PyTorch does not natively support gradient accumulation with mixed precision.
        2. After gradient accumulation, it flags parameters as ready for synchronization.

        The gradient accumulation functions are stored to prevent them from going out of scope.

        References:
        - https://github.com/NVIDIA/Megatron-LM/issues/690
        - https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html
        - https://arxiv.org/abs/2006.15704 (page 5)
        """
        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                # Expand the parameter to get a non-leaf view whose grad_fn gives access to
                # the parameter's AccumulateGrad node.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function (AccumulateGrad) and register the hook on it,
                # so the hook fires once the gradient for this parameter is ready.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
                grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
                # Keep a reference so the accumulator (and its hook) is not garbage collected.
                self.grad_accs.append(grad_acc)

    def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
        """
        Creates a hook for each parameter to handle gradient accumulation and synchronization.
        """
        def param_hook(*unused):
            """
            The hook called after the gradient is ready. It performs the following:
            1. Accumulates the gradient into the main gradient.
            2. Adds a post-backward callback to wait for gradient synchronization completion.
            3. Marks the parameter as ready for synchronization.
            """
            if param.requires_grad:
                assert param.grad is not None
                param.main_grad.add_(param.grad.data)  # accumulate the gradient into main_grad
                param.grad = None
                # Skip gradient synchronization (gradient accumulation / PP micro-batches).
                if self.require_backward_grad_sync:
                    # Queue a callback that waits for gradient synchronization to complete. It runs
                    # after the backward pass finishes and must be added once per backward pass.
                    if not self._post_backward_callback_set:
                        Variable._execution_engine.queue_callback(self._post_backward)
                        self._post_backward_callback_set = True
                    # Mark the parameter as ready for gradient synchronization.
                    bucket_manager.mark_param_as_ready(param)
        return param_hook

    @contextlib.contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization."""
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            self.require_backward_grad_sync = True
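
    # Minimal usage sketch (not part of this module's API): during gradient accumulation,
    # wrap all but the last micro-batch in no_sync() so gradients are only all-reduced once
    # per optimizer step. `model`, `micro_batches`, and `compute_loss` are hypothetical placeholders.
    #
    #     for i, micro_batch in enumerate(micro_batches):
    #         ctx = model.no_sync() if i < len(micro_batches) - 1 else contextlib.nullcontext()
    #         with ctx:
    #             loss = compute_loss(model(micro_batch))
    #             loss.backward()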

    def _post_backward(self):
        """
        A post-backward callback that waits for gradient synchronization to finish, then copies
        the synchronized gradients back to the parameters' .grad attribute.

        This method is called after the backward pass and before the optimizer step.
        """
        self.bucket_manager.wait()
        self._post_backward_callback_set = False
        # Copy main_grad to param.grad so the optimizer can update the parameters.
        for p in self.module.parameters():
            if p.requires_grad:
                # Cast to the parameter's dtype: param.grad must have the same dtype as the parameter itself.
                p.grad = p.main_grad.to(p.dtype)

    def reset(self):
        """
        Reset the bucket manager and zero out the gradients in the model.
        """
        self.bucket_manager.reset()
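
# Minimal end-to-end sketch (assumptions: `base_model`, `optimizer`, and `batch` are hypothetical,
# and torch.distributed plus the pgm process groups have already been initialized):
#
#     model = DataParallel(base_model)
#     loss = model(batch).sum()
#     loss.backward()      # param hooks accumulate into main_grad and trigger bucket all-reduces;
#                          # the queued _post_backward() waits for them and fills param.grad
#     optimizer.step()     # update parameters using the synchronized gradients
#     model.reset()        # reset the buckets (zeroing main_grad) before the next step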