use global pgm for ddp
parent 134d48b658
commit d0d6d8994f
@@ -2,9 +2,10 @@ import contextlib
 import torch
 import torch.distributed as dist
 from torch import nn
+import src.distributed.process_group_manager as pgm
 
 class DataParallel(nn.Module):
-    def __init__(self, module, process_group):
+    def __init__(self, module):
         """
         Initializes the DataParallel wrapper for a given module.
 
@@ -15,7 +16,7 @@ class DataParallel(nn.Module):
         """
         super().__init__()
         self.module = module
-        self.process_group = process_group # process group for gradient synchronization. could be data parallel group and context parallel group
+        self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
         self.dp_world_size = dist.get_world_size(group=self.process_group)
         self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
         self.register_backward_hook(self._allreduce_grads)
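The point of the change: DataParallel no longer receives a process group from its caller, it reads the data-parallel group from a module-level process group manager (pgm) that is created once at startup. Below is a minimal, self-contained sketch of that global-singleton pattern, assuming each data-parallel group is built from contiguous ranks; ProcessGroupManager, setup_process_group_manager and dp_size here are illustrative stand-ins, not the actual contents of src/distributed/process_group_manager. The hunks after the sketch apply the same refactor to the bucketed DataParallel variant (the file that imports BucketManager).

# Illustrative sketch of a global process-group-manager singleton (assumption:
# not the repo's real module). The wrapper imports this module as `pgm` and
# reads pgm.process_group_manager.dp_group instead of taking a process_group
# argument in __init__.
import torch.distributed as dist

class ProcessGroupManager:
    def __init__(self, dp_size):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        assert world_size % dp_size == 0, "dp_size must divide world_size"
        self.dp_group = None
        # Every rank must call new_group() for every group, in the same order,
        # but only keeps the group it actually belongs to.
        for start in range(0, world_size, dp_size):
            ranks = list(range(start, start + dp_size))
            group = dist.new_group(ranks)
            if rank in ranks:
                self.dp_group = group
        self.dp_rank = dist.get_rank(group=self.dp_group)
        self.dp_world_size = dist.get_world_size(group=self.dp_group)

process_group_manager = None  # set once at startup, then imported everywhere as pgm

def setup_process_group_manager(dp_size):
    global process_group_manager
    process_group_manager = ProcessGroupManager(dp_size)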
@@ -4,9 +4,10 @@ import torch.distributed as dist
 from torch import nn
 from torch.autograd import Variable
 from src.parallel.data_parallel.bucket import BucketManager
+import src.distributed.process_group_manager as pgm
 
 class DataParallel(nn.Module):
-    def __init__(self, module, process_group, bucket_cap_mb=25, grad_type = torch.float32):
+    def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):
         """
         Initialize the DataParallel module.
 
@@ -20,7 +21,8 @@ class DataParallel(nn.Module):
         """
         super().__init__()
         self.module = module
-        self.process_group = process_group # process group for gradient synchronization. could be data parallel group and context parallel group
+        #TODO: refactor by using pgm to get world size etc
+        self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
         self.dp_world_size = dist.get_world_size(group=self.process_group)
         self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
         grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
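For reference, grad_size presumably feeds the bucket sizing that follows this hunk: the bucket_cap_mb capacity has to be turned into a count of gradient elements. The lines below only spell out the arithmetic implied by the grad_size comment, not the repo's exact BucketManager code.

# Arithmetic implied by the grad_size line above (illustrative only):
bucket_cap_mb = 25                        # default bucket capacity in megabytes
grad_size = 4                             # fp32 gradient: 4 bytes (bf16: 2 bytes)
bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size
print(bucket_size)                        # 6553600 fp32 elements per 25 MB bucket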
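End to end, the call site changes from DataParallel(model, dp_group) to DataParallel(model). A hypothetical launch sequence under the sketch above; setup_process_group_manager, the torchrun assumption and the pure data-parallel layout are illustrations, not the repo's actual entry point:

# Hypothetical driver code after this commit: the global manager is set up
# once, and DataParallel no longer needs a process_group argument.
# DataParallel is the wrapper from the diff; setup_process_group_manager is
# the illustrative helper sketched earlier.
import os
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))  # assuming torchrun
    # Pure data parallelism for this example: every rank joins the DP group.
    setup_process_group_manager(dp_size=dist.get_world_size())

    model = torch.nn.Linear(1024, 1024).cuda()  # any nn.Module stands in here
    model = DataParallel(model)                 # was: DataParallel(model, dp_group)

if __name__ == "__main__":
    main()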