use global pgm for ddp
This commit is contained in:
parent
2b2781a374
commit
9d53e9afa6
@ -16,8 +16,6 @@ class DataParallel(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
|
||||
self.dp_world_size = dist.get_world_size(group=self.process_group)
|
||||
self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
|
||||
self.register_backward_hook(self._allreduce_grads)
|
||||
|
||||
@ -39,8 +37,8 @@ class DataParallel(nn.Module):
|
||||
# No synchronization needed during gradient accumulation, except at the final accumulation step.
|
||||
# 324K tokens/s/gpu -> 334K tokens/s/gpu
|
||||
if self.require_backward_grad_sync:
|
||||
dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=self.process_group)
|
||||
grad /= self.dp_world_size
|
||||
dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
|
||||
grad /= pgm.process_group_manager.dp_world_size
|
||||
return grad
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
||||
@ -21,13 +21,10 @@ class DataParallel(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
self.module = module
|
||||
#TODO: refactor by using pgm to get world size etc
|
||||
self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
|
||||
self.dp_world_size = dist.get_world_size(group=self.process_group)
|
||||
self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
|
||||
grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
|
||||
bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket
|
||||
self.bucket_manager = BucketManager(module.parameters(), self.process_group, bucket_size, grad_type)
|
||||
self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.dp_group, bucket_size, grad_type)
|
||||
self.register_backward_hook()
|
||||
self._post_backward_callback_set = False # whether the callback for wait gradient synchronization is set
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user