some dp renaming

This commit is contained in:
ferdinand.mom 2024-11-04 14:41:11 +00:00
parent 814e2a96ad
commit 90868144a7

View File

@ -58,9 +58,9 @@ class DataParallel(nn.Module):
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
self.grad_accs.append(grad_acc)
grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
self.grad_accs.append(grad_acc_fn)
def _make_param_hook(self, param: torch.nn.Parameter,bucket_manager: BucketManager):
"""