use global pgm for ddp

ferdinand.mom 2024-10-18 14:59:26 +00:00
parent 134d48b658
commit d0d6d8994f
2 changed files with 7 additions and 4 deletions

View File

@@ -2,9 +2,10 @@ import contextlib
 import torch
 import torch.distributed as dist
 from torch import nn
+import src.distributed.process_group_manager as pgm
 class DataParallel(nn.Module):
-    def __init__(self, module, process_group):
+    def __init__(self, module):
         """
         Initializes the DataParallel wrapper for a given module.
@@ -15,7 +16,7 @@ class DataParallel(nn.Module):
         """
         super().__init__()
         self.module = module
-        self.process_group = process_group # process group for gradient synchronization. could be data parallel group and context parallel group
+        self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
         self.dp_world_size = dist.get_world_size(group=self.process_group)
         self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
         self.register_backward_hook(self._allreduce_grads)
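With this change the wrapper resolves its own communication group, so callers stop passing one in. A minimal before/after usage sketch, assuming torch.distributed is already initialized and the global manager has been set up by the training script (the DataParallel class is the one defined above; no setup API is shown in this commit):

import torch.distributed as dist
import src.distributed.process_group_manager as pgm

# before this commit: the caller created the group and handed it to the wrapper
#   dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))
#   model = DataParallel(raw_model, dp_group)

# after this commit: the wrapper reads pgm.process_group_manager.dp_group itself
#   model = DataParallel(raw_model)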

View File

@@ -4,9 +4,10 @@ import torch.distributed as dist
 from torch import nn
 from torch.autograd import Variable
 from src.parallel.data_parallel.bucket import BucketManager
+import src.distributed.process_group_manager as pgm
 class DataParallel(nn.Module):
-    def __init__(self, module, process_group, bucket_cap_mb=25, grad_type = torch.float32):
+    def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):
         """
         Initialize the DataParallel module.
@@ -20,7 +21,8 @@ class DataParallel(nn.Module):
         """
         super().__init__()
         self.module = module
-        self.process_group = process_group # process group for gradient synchronization. could be data parallel group and context parallel group
+        #TODO: refactor by using pgm to get world size etc
+        self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
         self.dp_world_size = dist.get_world_size(group=self.process_group)
         self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
         grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
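Both files now import src.distributed.process_group_manager as pgm and read a module-level process_group_manager object that exposes the data-parallel group as dp_group. That module is not part of this diff; the sketch below shows one way such a global manager could look. The class name, constructor argument, and setup helper are assumptions — only the pgm.process_group_manager.dp_group access is taken from the code above.

# src/distributed/process_group_manager.py -- hypothetical minimal version
import torch.distributed as dist

class ProcessGroupManager:
    def __init__(self, dp_size):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        assert world_size % dp_size == 0, "world size must be divisible by dp_size"
        self.dp_group = None
        # every rank must call new_group for every subgroup, in the same order,
        # even for subgroups it does not belong to
        for start in range(0, world_size, dp_size):
            ranks = list(range(start, start + dp_size))
            group = dist.new_group(ranks=ranks)
            if rank in ranks:
                self.dp_group = group

# module-level singleton read by DataParallel as pgm.process_group_manager.dp_group
process_group_manager = None

def setup_process_group_manager(dp_size):
    global process_group_manager
    process_group_manager = ProcessGroupManager(dp_size)

A training script would call setup_process_group_manager(dp_size) once, right after dist.init_process_group(), before any DataParallel wrapper is constructed.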