diff --git a/src/parallel/data_parallel/data_parallel_bucket.py b/src/parallel/data_parallel/data_parallel_bucket.py
index 13909fe..4423d6f 100644
--- a/src/parallel/data_parallel/data_parallel_bucket.py
+++ b/src/parallel/data_parallel/data_parallel_bucket.py
@@ -58,9 +58,9 @@ class DataParallel(nn.Module):
                 # Expand so we get access to grad_fn.
                 param_tmp = param.expand_as(param)
                 # Get the gradient accumulator function.
-                grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
-                self.grad_accs.append(grad_acc)
+                grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
+                grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
+                self.grad_accs.append(grad_acc_fn)
 
     def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
         """
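
For context, the lines this hunk renames use the standard PyTorch trick for hooking a parameter's gradient accumulator: expand the parameter as itself to obtain a view with a grad_fn, then take next_functions[0][0] to reach the AccumulateGrad node and register a hook on it. Below is a minimal, self-contained sketch of that pattern; the toy nn.Linear model and the print-based hook are illustrative assumptions, not part of this repository's DataParallel or BucketManager code.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)   # stand-in model, not the repo's wrapped module
grad_accs = []            # keep references so the accumulator nodes stay alive

for param in model.parameters():
    if param.requires_grad:
        # Expanding the parameter as itself yields a view with a grad_fn;
        # next_functions[0][0] of that view is the AccumulateGrad node for `param`.
        param_tmp = param.expand_as(param)
        grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
        # The hook fires right after the parameter's gradient has been accumulated.
        grad_acc_fn.register_hook(lambda *unused, p=param: print("grad ready:", tuple(p.shape)))
        grad_accs.append(grad_acc_fn)

loss = model(torch.randn(3, 4)).sum()
loss.backward()           # each registered hook prints once per parameter

In the patched class the same hook slot is filled by _make_param_hook(param, self.bucket_manager), which lets the bucket manager learn that a given parameter's gradient is ready instead of printing.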