# Copied from https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py

from __future__ import division
from __future__ import unicode_literals

from typing import Iterable, Optional
import weakref
import copy
import contextlib

import torch


def to_float_maybe(x):
    # Promote half-precision tensors to float32 so that the EMA buffers are
    # accumulated in full precision.
    return x.float() if x.dtype in [torch.float16, torch.bfloat16] else x


# Partially based on:
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
    """
    Maintains (exponential) moving average of a set of parameters.

    Args:
        parameters: Iterable of `torch.nn.Parameter` (typically from
            `model.parameters()`).
        decay: The exponential decay.
        use_num_updates: Whether to use number of updates when computing
            averages.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float,
        use_num_updates: bool = True
    ):
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        parameters = list(parameters)
        self.shadow_params = [
            to_float_maybe(p.clone().detach())
            for p in parameters
            if p.requires_grad
        ]
        self.collected_params = None
        # By maintaining only a weakref to each parameter,
        # we maintain the old GC behaviour of ExponentialMovingAverage:
        # if the model goes out of scope but the ExponentialMovingAverage
        # is kept, no references to the model or its parameters will be
        # maintained, and the model will be cleaned up.
        self._params_refs = [weakref.ref(p) for p in parameters]

    def _get_parameters(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]]
    ) -> Iterable[torch.nn.Parameter]:
        if parameters is None:
            parameters = [p() for p in self._params_refs]
            if any(p is None for p in parameters):
                raise ValueError(
                    "(One of) the parameters with which this "
                    "ExponentialMovingAverage "
                    "was initialized no longer exists (was garbage collected);"
                    " please either provide `parameters` explicitly or keep "
                    "the model to which they belong from being garbage "
                    "collected."
                )
            return parameters
        else:
            parameters = list(parameters)
            if len(parameters) != len(self.shadow_params):
                raise ValueError(
                    "Number of parameters passed as argument is different "
                    "from number of shadow parameters maintained by this "
                    "ExponentialMovingAverage"
                )
            return parameters

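    # Note on the decay schedule used by `update()` below: when
    # `use_num_updates` is enabled, the effective decay at the n-th update is
    # min(decay, (1 + n) / (10 + n)). For illustration: n = 1 gives
    # 2 / 11 ~= 0.18, n = 10 gives 11 / 20 = 0.55, and n = 100 gives
    # 101 / 110 ~= 0.92; with a configured decay of 0.999, the warm-up term
    # is the smaller of the two until roughly n ~= 9000. Early averages thus
    # track the current weights closely rather than the initialization.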
    def update(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, e.g. after each
        `optimizer.step()` call.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
                parameters used to initialize this object. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            decay = min(
                decay,
                (1 + self.num_updates) / (10 + self.num_updates)
            )
        one_minus_decay = 1.0 - decay
        if parameters[0].device != self.shadow_params[0].device:
            self.to(device=parameters[0].device)
        with torch.no_grad():
            parameters = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, parameters):
                torch.lerp(
                    s_param,
                    param.to(dtype=s_param.dtype),
                    one_minus_decay,
                    out=s_param,
                )

    def copy_to(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                param.data.copy_(s_param.data)

    def store(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        self.collected_params = [
            param.clone()
            for param in parameters
            if param.requires_grad
        ]

    def restore(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Restore the parameters stored with the `store` method.

        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters with `store`
        before calling `copy_to`. After validation (or model saving), use this
        to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        if self.collected_params is None:
            raise RuntimeError(
                "This ExponentialMovingAverage has no `store()`ed weights "
                "to `restore()`"
            )
        parameters = self._get_parameters(parameters)
        for c_param, param in zip(self.collected_params, parameters):
            if param.requires_grad:
                param.data.copy_(c_param.data)

    @contextlib.contextmanager
    def average_parameters(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ):
        r"""
        Context manager for validation/inference with averaged parameters.

        Equivalent to:

            ema.store()
            ema.copy_to()
            try:
                ...
            finally:
                ema.restore()

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        self.store(parameters)
        self.copy_to(parameters)
        try:
            yield
        finally:
            self.restore(parameters)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied
                to floating-point buffers
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype)
            if p.is_floating_point()
            else p.to(device=device)
            for p in self.shadow_params
        ]
        if self.collected_params is not None:
            self.collected_params = [
                p.to(device=device, dtype=dtype)
                if p.is_floating_point()
                else p.to(device=device)
                for p in self.collected_params
            ]
        return

    def state_dict(self) -> dict:
        r"""Returns the state of the ExponentialMovingAverage as a dict."""
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "num_updates": self.num_updates,
            "shadow_params": self.shadow_params,
            "collected_params": self.collected_params
        }

    def load_state_dict(self, state_dict: dict) -> None:
        r"""Loads the ExponentialMovingAverage state.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)
        self.decay = state_dict["decay"]
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.num_updates = state_dict["num_updates"]
        assert self.num_updates is None or isinstance(self.num_updates, int), \
            "Invalid num_updates"

        self.shadow_params = state_dict["shadow_params"]
        assert isinstance(self.shadow_params, list), \
            "shadow_params must be a list"
        assert all(
            isinstance(p, torch.Tensor) for p in self.shadow_params
        ), "shadow_params must all be Tensors"

        self.collected_params = state_dict["collected_params"]
        if self.collected_params is not None:
            assert isinstance(self.collected_params, list), \
                "collected_params must be a list"
            assert all(
                isinstance(p, torch.Tensor) for p in self.collected_params
            ), "collected_params must all be Tensors"
            assert len(self.collected_params) == len(self.shadow_params), \
                "collected_params and shadow_params had different lengths"

        if len(self.shadow_params) == len(self._params_refs):
            # Consistent with torch.optim.Optimizer, cast buffers to the
            # device and dtype of the corresponding parameters.
            params = [p() for p in self._params_refs]
            # If parameters have been garbage collected, just load the state
            # we were given without change.
            if not any(p is None for p in params):
                # ^ parameter references are still good
                for i, p in enumerate(params):
                    self.shadow_params[i] = to_float_maybe(
                        self.shadow_params[i].to(device=p.device, dtype=p.dtype)
                    )
                    if self.collected_params is not None:
                        self.collected_params[i] = self.collected_params[i].to(
                            device=p.device, dtype=p.dtype
                        )
        else:
            raise ValueError(
                "Tried to `load_state_dict()` with the wrong number of "
                "parameters in the saved state."
            )
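

if __name__ == "__main__":
    # Illustrative usage sketch: shows the intended call pattern --
    # `update()` after each optimizer step, `average_parameters()` for
    # evaluation, and `state_dict()` / `load_state_dict()` for checkpointing.
    # The model, data, and hyperparameters below are arbitrary placeholders.
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

    for _ in range(100):
        x = torch.randn(8, 4)
        y = torch.randn(8, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Update the shadow parameters after every optimizer step.
        ema.update()

    # Evaluate with the averaged weights; the original weights are restored
    # automatically when the context manager exits.
    with ema.average_parameters():
        with torch.no_grad():
            val_loss = torch.nn.functional.mse_loss(
                model(torch.randn(8, 4)), torch.randn(8, 1)
            )
    print(f"validation loss with EMA weights: {val_loss.item():.4f}")

    # Checkpointing: save the EMA state alongside the model and load it into
    # a freshly constructed ExponentialMovingAverage.
    checkpoint = {"model": model.state_dict(), "ema": ema.state_dict()}
    ema_resumed = ExponentialMovingAverage(model.parameters(), decay=0.999)
    ema_resumed.load_state_dict(checkpoint["ema"])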