some dp renaming
This commit is contained in:
parent
814e2a96ad
commit
90868144a7
@ -58,9 +58,9 @@ class DataParallel(nn.Module):
|
||||
# Expand so we get access to grad_fn.
|
||||
param_tmp = param.expand_as(param)
|
||||
# Get the gradient accumulator function.
|
||||
grad_acc = param_tmp.grad_fn.next_functions[0][0]
|
||||
grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
|
||||
self.grad_accs.append(grad_acc)
|
||||
grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
|
||||
grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
|
||||
self.grad_accs.append(grad_acc_fn)
|
||||
|
||||
def _make_param_hook(self, param: torch.nn.Parameter,bucket_manager: BucketManager):
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user