diff --git a/src/parallel/data_parallel/data_parallel.py b/src/parallel/data_parallel/data_parallel.py
index af0d82d..af39dc2 100644
--- a/src/parallel/data_parallel/data_parallel.py
+++ b/src/parallel/data_parallel/data_parallel.py
@@ -16,8 +16,6 @@ class DataParallel(nn.Module):
         """
         super().__init__()
         self.module = module
-        self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
-        self.dp_world_size = dist.get_world_size(group=self.process_group)
         self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
         self.register_backward_hook(self._allreduce_grads)
 
@@ -39,8 +37,8 @@ class DataParallel(nn.Module):
         # No synchronization needed during gradient accumulation, except at the final accumulation step.
         # 324K tokens/s/gpu -> 334K tokens/s/gpu
         if self.require_backward_grad_sync:
-            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=self.process_group)
-            grad /= self.dp_world_size
+            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
+            grad /= pgm.process_group_manager.dp_world_size
         return grad
 
     @contextlib.contextmanager
diff --git a/src/parallel/data_parallel/data_parallel_bucket.py b/src/parallel/data_parallel/data_parallel_bucket.py
index e9dd59a..c5b1c2a 100644
--- a/src/parallel/data_parallel/data_parallel_bucket.py
+++ b/src/parallel/data_parallel/data_parallel_bucket.py
@@ -21,13 +21,10 @@ class DataParallel(nn.Module):
         """
         super().__init__()
         self.module = module
-        #TODO: refactor by using pgm to get world size etc
-        self.process_group = pgm.process_group_manager.dp_group # process group for gradient synchronization. could be data parallel group and context parallel group
-        self.dp_world_size = dist.get_world_size(group=self.process_group)
         self.require_backward_grad_sync = True # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
         grad_size = 2 if grad_type == torch.bfloat16 else 4 # float32 gradient: 4 bytes
         bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size # number of gradients in one bucket
-        self.bucket_manager = BucketManager(module.parameters(), self.process_group, bucket_size, grad_type)
+        self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.dp_group, bucket_size, grad_type)
         self.register_backward_hook()
         self._post_backward_callback_set = False # whether the callback for wait gradient synchronization is set
 
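
Note on the change: both hunks drop the cached self.process_group and self.dp_world_size attributes and instead read the group and world size from pgm.process_group_manager at the point of use, so the wrapper holds no duplicated process-group state. Below is a minimal sketch of the same gradient-averaging hook pattern, assuming torch.distributed is already initialized; SimpleDataParallel and the explicit dp_group argument are illustrative stand-ins, not the repository's classes.

import torch
import torch.distributed as dist
import torch.nn as nn


class SimpleDataParallel(nn.Module):
    """Illustrative DP wrapper: average gradients over a data-parallel group."""

    def __init__(self, module: nn.Module, dp_group: dist.ProcessGroup):
        super().__init__()
        self.module = module
        # Stand-in for reading pgm.process_group_manager.dp_group at call time.
        self.dp_group = dp_group
        # Set to False while accumulating gradients; only the last micro-batch syncs.
        self.require_backward_grad_sync = True
        # Register a per-parameter gradient hook that all-reduces each gradient.
        for p in self.module.parameters():
            if p.requires_grad:
                p.register_hook(self._allreduce_grad)

    def _allreduce_grad(self, grad: torch.Tensor) -> torch.Tensor:
        # Skip communication during gradient accumulation steps.
        if self.require_backward_grad_sync:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=self.dp_group)
            grad /= dist.get_world_size(group=self.dp_group)
        return grad

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

Looking up the group through the process-group manager (rather than caching it on the wrapper, as the sketch does for brevity) keeps a single source of truth for how data-parallel and context-parallel groups are laid out, which appears to be the motivation for this refactor.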