use global pgm for ddp
parent 2b2781a374
commit 9d53e9afa6
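The change removes the per-instance self.process_group and self.dp_world_size attributes and instead reads the data-parallel group straight from the global process group manager (pgm.process_group_manager) at every call site. The manager itself is not part of this diff; a minimal sketch of the pattern it assumes is shown below, where only dp_group and dp_world_size come from the hunks and every other name is illustrative.

# Hedged sketch of a module-level process group manager; only dp_group and
# dp_world_size appear in the diff, the rest is assumed for illustration.
import torch.distributed as dist

class ProcessGroupManager:
    def __init__(self):
        # Simplest case: every rank belongs to a single data-parallel group.
        self.dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))
        self.dp_world_size = dist.get_world_size(group=self.dp_group)

process_group_manager = None  # module-level singleton, read everywhere as pgm.process_group_manager

def setup_process_group_manager():  # hypothetical setup helper, not shown in the diff
    global process_group_manager
    process_group_manager = ProcessGroupManager()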
@@ -16,8 +16,6 @@ class DataParallel(nn.Module):
         """
         super().__init__()
         self.module = module
-        self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
-        self.dp_world_size = dist.get_world_size(group=self.process_group)
         self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
         self.register_backward_hook(self._allreduce_grads)
 
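register_backward_hook(self._allreduce_grads) is kept as context in the hunk above, but its body lies outside the diff. One plausible reading, assumed rather than taken from this commit, is that it attaches _allreduce_grads as a per-parameter gradient hook:

# Assumed sketch; the real register_backward_hook body is not shown in this diff.
def register_backward_hook(self, hook):
    # Tensor.register_hook calls hook(grad) once the parameter's gradient is
    # computed, and the tensor the hook returns replaces that gradient.
    for p in self.module.parameters():
        if p.requires_grad:
            p.register_hook(hook)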
@@ -39,8 +37,8 @@ class DataParallel(nn.Module):
         # No synchronization needed during gradient accumulation, except at the final accumulation step.
         # 324K tokens/s/gpu -> 334K tokens/s/gpu
         if self.require_backward_grad_sync:
-            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=self.process_group)
-            grad /= self.dp_world_size
+            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
+            grad /= pgm.process_group_manager.dp_world_size
         return grad
 
     @contextlib.contextmanager
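The trailing @contextlib.contextmanager context line, together with the require_backward_grad_sync comment ("Set to False when using gradient accumulation"), points at a no-sync style context manager. A hedged sketch of that pattern, with the name no_sync assumed:

# Assumed sketch: skip the gradient all-reduce on accumulation steps, then restore it.
@contextlib.contextmanager
def no_sync(self):
    self.require_backward_grad_sync = False
    try:
        yield
    finally:
        self.require_backward_grad_sync = True

A typical use would wrap loss.backward() on every micro-batch except the last one of an accumulation window, so the all-reduce in _allreduce_grads only fires once per optimizer step.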
@@ -21,13 +21,10 @@ class DataParallel(nn.Module):
         """
         super().__init__()
         self.module = module
-        #TODO: refactor by using pgm to get world size etc
-        self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
-        self.dp_world_size = dist.get_world_size(group=self.process_group)
         self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
         grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
         bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket
-        self.bucket_manager = BucketManager(module.parameters(), self.process_group, bucket_size, grad_type)
+        self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.dp_group, bucket_size, grad_type)
         self.register_backward_hook()
         self._post_backward_callback_set = False # whether the callback for wait gradient synchronization is set
 
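In the bucketed variant above, the megabyte cap is converted into an element count before being handed to BucketManager. A purely illustrative check of that arithmetic (the bucket_cap_mb value is made up, not taken from the diff):

# Illustrative numbers only; reproduces the bucket-size arithmetic from the hunk above.
import torch

bucket_cap_mb = 25                                    # assumed example cap
grad_type = torch.bfloat16
grad_size = 2 if grad_type == torch.bfloat16 else 4   # bytes per gradient element
bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size
print(bucket_size)  # 13107200 bf16 gradient elements fit in a 25 MiB bucket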