# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.optim.optimizer import Optimizer, required
from ..common import param_conv
class QHM(Optimizer):
r"""Implements the quasi-hyperbolic momentum (QHM) optimization algorithm
`(Ma and Yarats, 2019)`_.
Note that many other optimization algorithms are accessible via specific
parameterizations of QHM. See :func:`from_accsgd()`,
:func:`from_robust_momentum()`, etc. for details.
Args:
params (iterable):
iterable of parameters to optimize or dicts defining parameter
groups
lr (float):
learning rate (:math:`\alpha` from the paper)
momentum (float):
momentum factor (:math:`\beta` from the paper)
nu (float):
immediate discount factor (:math:`\nu` from the paper)
weight_decay (float, optional):
weight decay (L2 regularization coefficient, times two)
(default: 0.0)
weight_decay_type (str, optional):
method of applying the weight decay:
``"grad"`` for accumulation in the gradient
(same as :class:`torch.optim.SGD`) or
``"direct"`` for direct application to the parameters
(default: ``"grad"``)
Example:
>>> optimizer = qhoptim.pyt.QHM(
... model.parameters(), lr=1.0, nu=0.7, momentum=0.999)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
.. _`(Ma and Yarats, 2019)`: https://arxiv.org/abs/1810.06801
.. note::
Mathematically, QHM is a simple interpolation between plain SGD and
momentum:
.. math::
\begin{align*}
g_{t + 1} &\leftarrow
\beta \cdot g_t +
(1 - \beta) \cdot \nabla_t \\
\theta_{t + 1} &\leftarrow
            \theta_t - \alpha \left[ (1 - \nu) \cdot \nabla_t +
\nu \cdot g_{t + 1} \right]
\end{align*}
Here, :math:`\alpha` is the learning rate, :math:`\beta` is the momentum
factor, and :math:`\nu` is the "immediate discount" factor which
controls the interpolation between plain SGD and momentum.
:math:`g_t` is the momentum buffer, :math:`\theta_t` is the parameter
vector, and :math:`\nabla_t` is the gradient with respect to
:math:`\theta_t`.
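        For instance (a sketch, assuming ``model`` is defined), setting
        :math:`\nu = 0` recovers the plain SGD endpoint of the interpolation:

        >>> # equivalent to torch.optim.SGD with lr=0.1 and no momentum
        >>> optimizer = qhoptim.pyt.QHM(
        ...     model.parameters(), lr=0.1, nu=0.0, momentum=0.9)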
.. note::
QHM uses **dampened** momentum. This means that when converting from
plain momentum to QHM, the learning rate must be scaled by
:math:`\frac{1}{1 - \beta}`. For example, momentum with learning rate
:math:`\alpha = 0.1` and momentum :math:`\beta = 0.9` should be
converted to QHM with learning rate :math:`\alpha = 1.0`.
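        As a sketch of that conversion (again assuming ``model`` is defined):

        >>> # plain momentum SGD with lr=0.1, momentum=0.9 becomes:
        >>> optimizer = qhoptim.pyt.QHM(
        ...     model.parameters(), lr=1.0, nu=1.0, momentum=0.9)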
"""
def __init__(self, params, lr=required, momentum=required, nu=required, weight_decay=0.0, weight_decay_type="grad"):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if weight_decay_type not in ("grad", "direct"):
raise ValueError("Invalid weight_decay_type value: {}".format(weight_decay_type))
defaults = {
"lr": lr,
"momentum": momentum,
"nu": nu,
"weight_decay": weight_decay,
"weight_decay_type": weight_decay_type,
}
super(QHM, self).__init__(params, defaults)
    def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional):
A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
lr, nu, momentum = group["lr"], group["nu"], group["momentum"]
weight_decay, weight_decay_type = group["weight_decay"], group["weight_decay_type"]
for p in group["params"]:
if p.grad is None:
continue
d_p = p.grad.data
param_state = self.state[p]
if weight_decay != 0:
if weight_decay_type == "grad":
                        d_p.add_(p.data, alpha=weight_decay)
elif weight_decay_type == "direct":
p.data.mul_(1.0 - lr * weight_decay)
else:
raise ValueError("Invalid weight decay type provided")
if len(param_state) == 0:
param_state["momentum_buffer"] = torch.zeros_like(p.data)
momentum_buffer = param_state["momentum_buffer"]
                # Update the momentum buffer: g_{t+1} <- beta * g_t + (1 - beta) * grad
                momentum_buffer.mul_(momentum).add_(d_p, alpha=1.0 - momentum)
                # Parameter step: theta_{t+1} <- theta_t - lr * (nu * g_{t+1} + (1 - nu) * grad)
                p.data.add_(momentum_buffer, alpha=-lr * nu)
                p.data.add_(d_p, alpha=-lr * (1.0 - nu))
return loss
@classmethod
def _params_to_dict(cls, params):
return {"lr": params.alpha, "nu": params.nu, "momentum": params.beta}
    @classmethod
def from_pid(cls, k_p, k_i, k_d):
r"""Calculates the QHM hyperparameters required to recover a PID
optimizer as described in `Recht (2018)`_.
Args:
k_p (float):
proportional gain (see reference)
k_i (float):
integral gain (see reference)
k_d (float):
derivative gain (see reference)
Returns:
Three-element ``dict`` containing ``lr``, ``momentum``, and ``nu``
to use in QHM.
Example:
>>> optimizer = qhoptim.pyt.QHM(
... model.parameters(),
... weight_decay=1e-4,
... **qhoptim.pyt.QHM.from_pid(
... k_p=-0.1, k_i=1.0, k_d=3.0))
.. _`Recht (2018)`: https://web.archive.org/web/20181027184056/http://www.argmin.net/2018/04/19/pid/
"""
return cls._params_to_dict(param_conv.from_pid(k_p, k_i, k_d))
    @classmethod
def from_synthesized_nesterov(cls, alpha, beta1, beta2):
r"""Calculates the QHM hyperparameters required to recover the
synthesized Nesterov optimizer (Section 6 of `Lessard et al. (2016)`_).
Args:
alpha (float):
learning rate
beta1 (float):
first momentum (see reference)
beta2 (float):
second momentum (see reference)
Returns:
Three-element ``dict`` containing ``lr``, ``momentum``, and ``nu``
to use in QHM.
Example:
>>> optimizer = qhoptim.pyt.QHM(
... model.parameters(),
... weight_decay=1e-4,
... **qhoptim.pyt.QHM.from_synthesized_nesterov(
... alpha=0.1, beta1=0.9, beta2=0.6))
.. _`Lessard et al. (2016)`: https://arxiv.org/abs/1408.3595
"""
return cls._params_to_dict(param_conv.from_synthesized_nesterov(alpha, beta1, beta2))
    @classmethod
def from_robust_momentum(cls, l, kappa, rho=None):
r"""Calculates the QHM hyperparameters required to recover the Robust
Momentum `(Cyrus et al., 2018)`_ or Triple Momentum
        `(Van Scoy et al., 2018)`_ optimizers.
Args:
l (float):
Lipschitz constant of gradient (see reference)
kappa (float):
condition ratio (see reference)
rho (float, optional):
noise-free convergence rate. If None, will return the
parameters for the Triple Momentum optimizer.
Returns:
Three-element ``dict`` containing ``lr``, ``momentum``, and ``nu``
to use in QHM.
Example:
>>> optimizer = qhoptim.pyt.QHM(
... model.parameters(),
... weight_decay=1e-4,
... **qhoptim.pyt.QHM.from_robust_momentum(
... l=5.0, kappa=15.0))
.. _`(Cyrus et al., 2018)`: https://arxiv.org/abs/1710.04753
        .. _`(Van Scoy et al., 2018)`: http://www.optimization-online.org/DB_FILE/2017/03/5908.pdf
"""
return cls._params_to_dict(param_conv.from_robust_momentum(l, kappa, rho))
    @classmethod
def from_accsgd(cls, delta, kappa, xi, eps=0.7):
r"""Calculates the QHM hyperparameters required to recover the AccSGD
optimizer `(Kidambi et al., 2018)`_.
Args:
delta (float):
short step (see reference)
kappa (float):
long step parameter (see reference)
xi (float):
statistical advantage parameter (see reference)
eps (float, optional):
arbitrary value, between 0 and 1 exclusive (see reference)
(default: 0.7)
Returns:
Three-element ``dict`` containing ``lr``, ``momentum``, and ``nu``
to use in QHM.
Example:
>>> optimizer = qhoptim.pyt.QHM(
... model.parameters(),
... weight_decay=1e-4,
... **qhoptim.pyt.QHM.from_accsgd(
... delta=0.1, kappa=1000.0, xi=10.0))
.. _`(Kidambi et al., 2018)`: https://arxiv.org/abs/1803.05591
"""
return cls._params_to_dict(param_conv.from_accsgd(delta, kappa, xi, eps))
    @classmethod
def from_two_state_optimizer(cls, h, k, l, m, q, z):
r"""Calculates the QHM hyperparameters required to recover the
following optimizer (named "TSO" in `Ma and Yarats (2019)`_):
.. math::
\begin{align*}
a_{t + 1} &\leftarrow
h \cdot a_t + k \cdot \theta_t + l \cdot \nabla_t \\
\theta_{t + 1} &\leftarrow
m \cdot a_t + q \cdot \theta_t + z \cdot \nabla_t
\end{align*}
Here, :math:`a_t` and :math:`\theta_t` are the two states and
:math:`\nabla_t` is the gradient with respect to :math:`\theta_t`.
Be careful that your coefficients satisfy the regularity conditions
from the reference.
Args:
h (float):
see description
k (float):
see description
l (float):
see description
m (float):
see description
q (float):
see description
z (float):
see description
Returns:
Three-element ``dict`` containing ``lr``, ``momentum``, and ``nu``
to use in QHM.
Example:
>>> optimizer = qhoptim.pyt.QHM(
... model.parameters(),
... weight_decay=1e-4,
... **qhoptim.pyt.QHM.from_two_state_optimizer(
... h=0.9, k=0.0, l=0.1, m=-0.09, q=1.0, z=-0.01))
.. _`Ma and Yarats (2019)`: https://arxiv.org/abs/1810.06801
"""
return cls._params_to_dict(param_conv.from_two_state_optimizer(h, k, l, m, q, z))
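

if __name__ == "__main__":
    # Minimal sanity-check sketch (an illustration, not part of the library API):
    # verify that one step on a single scalar parameter matches the recurrence
    # from the class docstring, assuming a freshly zeroed momentum buffer.
    lr, beta, nu = 0.1, 0.9, 0.7
    p = torch.nn.Parameter(torch.tensor([1.0]))
    optimizer = QHM([p], lr=lr, momentum=beta, nu=nu)
    optimizer.zero_grad()
    loss = (p ** 2).sum()  # gradient w.r.t. p is 2 * p
    loss.backward()
    grad = p.grad.detach().clone()
    theta = p.detach().clone()
    optimizer.step()
    g = (1.0 - beta) * grad  # g_1 = beta * g_0 + (1 - beta) * grad, with g_0 = 0
    expected = theta - lr * ((1.0 - nu) * grad + nu * g)
    assert torch.allclose(p.detach(), expected)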