diff --git a/src/parallel/data_parallel/data_parallel.py b/src/parallel/data_parallel/data_parallel.py
index dec0104..af0d82d 100644
--- a/src/parallel/data_parallel/data_parallel.py
+++ b/src/parallel/data_parallel/data_parallel.py
@@ -2,9 +2,10 @@ import contextlib
 import torch
 import torch.distributed as dist
 from torch import nn
+import src.distributed.process_group_manager as pgm
 
 class DataParallel(nn.Module):
-    def __init__(self, module, process_group):
+    def __init__(self, module):
         """
         Initializes the DataParallel wrapper for a given module.
 
@@ -15,7 +16,7 @@ class DataParallel(nn.Module):
         """
         super().__init__()
         self.module = module
-        self.process_group = process_group # process group for gradient synchronization. could be data parallel group and context parallel group
+        self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
         self.dp_world_size = dist.get_world_size(group=self.process_group)
         self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
         self.register_backward_hook(self._allreduce_grads)
diff --git a/src/parallel/data_parallel/data_parallel_bucket.py b/src/parallel/data_parallel/data_parallel_bucket.py
index 053598d..e9dd59a 100644
--- a/src/parallel/data_parallel/data_parallel_bucket.py
+++ b/src/parallel/data_parallel/data_parallel_bucket.py
@@ -4,9 +4,10 @@ import torch.distributed as dist
 from torch import nn
 from torch.autograd import Variable
 from src.parallel.data_parallel.bucket import BucketManager
+import src.distributed.process_group_manager as pgm
 
 class DataParallel(nn.Module):
-    def __init__(self, module, process_group, bucket_cap_mb=25, grad_type = torch.float32):
+    def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):
         """
         Initialize the DataParallel module.
 
@@ -20,7 +21,8 @@ class DataParallel(nn.Module):
         """
         super().__init__()
         self.module = module
-        self.process_group = process_group # process group for gradient synchronization. could be data parallel group and context parallel group
+        #TODO: refactor by using pgm to get world size etc
+        self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
         self.dp_world_size = dist.get_world_size(group=self.process_group)
         self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
         grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
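
For context, a rough sketch of how a call site changes with this refactor: both `DataParallel` wrappers now pull the data-parallel group from the global `process_group_manager` instead of taking a `process_group` argument. The sketch below assumes `torch.distributed` and `pgm.process_group_manager` have already been initialized earlier in the training script (the exact setup call is not shown in this diff), and the `nn.Linear` model is just a placeholder.

```python
import torch.distributed as dist
from torch import nn

import src.distributed.process_group_manager as pgm
from src.parallel.data_parallel.data_parallel import DataParallel

# Assumed to have happened earlier in the training script:
#   - dist.init_process_group(...) for the global world
#   - initialization of pgm.process_group_manager (setup call not shown here)

model = nn.Linear(16, 16)  # placeholder module

# Old call site (before this diff): the caller passed the group explicitly.
# dp_model = DataParallel(model, pgm.process_group_manager.dp_group)

# New call site: the wrapper resolves the data-parallel group itself.
dp_model = DataParallel(model)

# Gradient accumulation is unchanged: skip the all-reduce for intermediate
# micro-batches, re-enable it before the last one.
dp_model.require_backward_grad_sync = False
# ... forward/backward over the first N-1 micro-batches ...
dp_model.require_backward_grad_sync = True
# ... forward/backward over the final micro-batch triggers the all-reduce ...
```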