import contextlib
import torch
import torch.distributed as dist
from torch import nn

import src.distributed.process_group_manager as pgm


class DataParallel(nn.Module):
    def __init__(self, module):
"""
Initializes the DataParallel wrapper for a given module .
Args :
module ( nn . Module ) : The model to be wrapped for data parallelism .
process_group ( torch . distributed . ProcessGroup ) : The process group used for gradient synchronization .
It could be a data parallel or context parallel group .
"""
super ( ) . __init__ ( )
self . module = module
        # Process group used for gradient synchronization; it may be the data
        # parallel group or the context parallel group.
        self.process_group = pgm.process_group_manager.dp_group
        self.dp_world_size = dist.get_world_size(group=self.process_group)
        # Whether to synchronize gradients during the backward pass. Set to False
        # for all but the last micro-step when using gradient accumulation.
        self.require_backward_grad_sync = True
        self.register_backward_hook(self._allreduce_grads)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def register_backward_hook(self, hook):
        """
        Registers a backward hook on every parameter of the model that requires
        gradients. Each hook fires once that parameter's gradient has been
        computed, and its return value replaces the gradient.
        """
        for p in self.module.parameters():
            if p.requires_grad:
                p.register_hook(hook)

    def _allreduce_grads(self, grad):
        """
        Performs an all-reduce operation to synchronize gradients across multiple processes.
        """
        # No synchronization is needed during gradient accumulation, except at the
        # final accumulation step; skipping the intermediate all-reduces raised
        # throughput from 324K to 334K tokens/s/gpu.
        if self.require_backward_grad_sync:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=self.process_group)
            # Summing then dividing by the world size leaves every rank holding
            # the average gradient over the data parallel group.
            grad /= self.dp_world_size
        return grad

    @contextlib.contextmanager
    def no_sync(self):
        """
        A context manager to temporarily disable gradient synchronization.
        This is useful for performing multiple backward passes during gradient
        accumulation without synchronizing gradients in between.
        """
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            # Restore synchronization even if the wrapped block raises.
            self.require_backward_grad_sync = True
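

# Usage sketch (illustrative, not part of the module): a gradient-accumulation
# loop that skips the all-reduce on every micro-step except the last one.
# `model`, `optimizer`, and `micro_batches` are hypothetical placeholders, and
# torch.distributed plus the process_group_manager are assumed to have been
# initialized elsewhere.
#
#     model = DataParallel(model)
#     optimizer.zero_grad()
#     for step, micro_batch in enumerate(micro_batches):
#         is_last_step = step == len(micro_batches) - 1
#         # Suppress gradient synchronization until the final micro-step.
#         context = contextlib.nullcontext() if is_last_step else model.no_sync()
#         with context:
#             loss = model(micro_batch).mean()
#             loss.backward()
#     optimizer.step()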