some dp renaming
This commit is contained in:
parent
814e2a96ad
commit
90868144a7
@ -58,9 +58,9 @@ class DataParallel(nn.Module):
|
|||||||
# Expand so we get access to grad_fn.
|
# Expand so we get access to grad_fn.
|
||||||
param_tmp = param.expand_as(param)
|
param_tmp = param.expand_as(param)
|
||||||
# Get the gradient accumulator function.
|
# Get the gradient accumulator function.
|
||||||
grad_acc = param_tmp.grad_fn.next_functions[0][0]
|
grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
|
||||||
grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
|
grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
|
||||||
self.grad_accs.append(grad_acc)
|
self.grad_accs.append(grad_acc_fn)
|
||||||
|
|
||||||
def _make_param_hook(self, param: torch.nn.Parameter,bucket_manager: BucketManager):
|
def _make_param_hook(self, param: torch.nn.Parameter,bucket_manager: BucketManager):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user