use global pgm for ddp

This commit is contained in:
ferdinand.mom 2024-10-18 15:34:17 +00:00
parent 2b2781a374
commit 9d53e9afa6
2 changed files with 3 additions and 8 deletions

View File

@ -16,8 +16,6 @@ class DataParallel(nn.Module):
""" """
super().__init__() super().__init__()
self.module = module self.module = module
self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
self.dp_world_size = dist.get_world_size(group=self.process_group)
self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
self.register_backward_hook(self._allreduce_grads) self.register_backward_hook(self._allreduce_grads)
@ -39,8 +37,8 @@ class DataParallel(nn.Module):
# No synchronization needed during gradient accumulation, except at the final accumulation step. # No synchronization needed during gradient accumulation, except at the final accumulation step.
# 324K tokens/s/gpu -> 334K tokens/s/gpu # 324K tokens/s/gpu -> 334K tokens/s/gpu
if self.require_backward_grad_sync: if self.require_backward_grad_sync:
dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=self.process_group) dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
grad /= self.dp_world_size grad /= pgm.process_group_manager.dp_world_size
return grad return grad
@contextlib.contextmanager @contextlib.contextmanager

View File

@ -21,13 +21,10 @@ class DataParallel(nn.Module):
""" """
super().__init__() super().__init__()
self.module = module self.module = module
#TODO: refactor by using pgm to get world size etc
self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
self.dp_world_size = dist.get_world_size(group=self.process_group)
self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket
self.bucket_manager = BucketManager(module.parameters(), self.process_group, bucket_size, grad_type) self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.dp_group, bucket_size, grad_type)
self.register_backward_hook() self.register_backward_hook()
self._post_backward_callback_set = False # whether the callback for wait gradient synchronization is set self._post_backward_callback_set = False # whether the callback for wait gradient synchronization is set