Quasi-hyperbolic optimizers for PyTorch
Getting started
The PyTorch optimizer classes are qhoptim.pyt.QHM and qhoptim.pyt.QHAdam. Use these optimizers as you would any other PyTorch optimizer:
>>> from qhoptim.pyt import QHM, QHAdam
# something like this for QHM
>>> optimizer = QHM(model.parameters(), lr=1.0, nu=0.7, momentum=0.999)
# or something like this for QHAdam
>>> optimizer = QHAdam(
... model.parameters(), lr=1e-3, nus=(0.7, 1.0), betas=(0.995, 0.999))
# a single optimization step
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
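For a fuller picture, here is a minimal end-to-end sketch. The linear model, loss function, and random data below are placeholders chosen only for illustration and are not part of qhoptim; the optimizer construction follows the snippet above.

import torch
import torch.nn as nn

from qhoptim.pyt import QHAdam

# Placeholder model, loss, and data -- illustration only.
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
optimizer = QHAdam(
    model.parameters(), lr=1e-3, nus=(0.7, 1.0), betas=(0.995, 0.999))

for step in range(100):
    input = torch.randn(32, 10)   # fake minibatch
    target = torch.randn(32, 1)   # fake targets
    optimizer.zero_grad()
    loss_fn(model(input), target).backward()
    optimizer.step()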
QHM API reference
class qhoptim.pyt.QHM(params, lr=<required parameter>, momentum=<required parameter>, nu=<required parameter>, weight_decay=0.0, weight_decay_type='grad')

Implements the quasi-hyperbolic momentum (QHM) optimization algorithm (Ma and Yarats, 2019).

Note that many other optimization algorithms are accessible via specific parameterizations of QHM. See from_accsgd(), from_robust_momentum(), etc. for details.

Parameters
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (\(\alpha\) from the paper)
momentum (float) – momentum factor (\(\beta\) from the paper)
nu (float) – immediate discount factor (\(\nu\) from the paper)
weight_decay (float, optional) – weight decay (L2 regularization coefficient, times two) (default: 0.0)
weight_decay_type (str, optional) – method of applying the weight decay: "grad" for accumulation in the gradient (same as torch.optim.SGD) or "direct" for direct application to the parameters (default: "grad")
Example
>>> optimizer = qhoptim.pyt.QHM(
...     model.parameters(), lr=1.0, nu=0.7, momentum=0.999)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
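To apply the weight decay directly to the parameters instead of accumulating it in the gradient, set weight_decay_type as documented above; the specific hyperparameter values here are arbitrary and only illustrative:

>>> optimizer = qhoptim.pyt.QHM(
...     model.parameters(), lr=1.0, nu=0.7, momentum=0.999,
...     weight_decay=1e-4, weight_decay_type="direct")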
Note
Mathematically, QHM is a simple interpolation between plain SGD and momentum:
\[\begin{split}\begin{align*} g_{t + 1} &\leftarrow \beta \cdot g_t + (1 - \beta) \cdot \nabla_t \\ \theta_{t + 1} &\leftarrow \theta_t - \alpha \left[ (1 - \nu) \cdot \nabla_t + \nu \cdot g_{t + 1} \right] \end{align*}\end{split}\]

Here, \(\alpha\) is the learning rate, \(\beta\) is the momentum factor, and \(\nu\) is the “immediate discount” factor which controls the interpolation between plain SGD and momentum. \(g_t\) is the momentum buffer, \(\theta_t\) is the parameter vector, and \(\nabla_t\) is the gradient with respect to \(\theta_t\).
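As a concrete reading of these two equations, here is a minimal sketch of a single QHM update on raw tensors. It is not the library's implementation (which additionally handles parameter groups and weight decay); it only mirrors the math above, and the tensor values are placeholders.

import torch

def qhm_step_sketch(theta, g, grad, lr, beta, nu):
    """One QHM update, mirroring the equations in the note above."""
    # g_{t+1} <- beta * g_t + (1 - beta) * grad_t
    g.mul_(beta).add_(grad, alpha=1.0 - beta)
    # theta_{t+1} <- theta_t - lr * [(1 - nu) * grad_t + nu * g_{t+1}]
    theta.add_((1.0 - nu) * grad + nu * g, alpha=-lr)
    return theta, g

theta = torch.zeros(3)   # parameters
g = torch.zeros(3)       # momentum buffer, initialized to zero
grad = torch.ones(3)     # stand-in gradient
qhm_step_sketch(theta, g, grad, lr=1.0, beta=0.999, nu=0.7)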
Note
QHM uses dampened momentum. This means that when converting from plain momentum to QHM, the learning rate must be scaled by \(\frac{1}{1 - \beta}\). For example, momentum with learning rate \(\alpha = 0.1\) and momentum \(\beta = 0.9\) should be converted to QHM with learning rate \(\alpha = 1.0\).
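For instance, the conversion in this note can be written out explicitly. The numbers are the ones from the note, and setting \(\nu = 1\) is what reduces QHM to (dampened) momentum per the equations above:

>>> beta = 0.9                      # momentum factor of the plain-momentum run
>>> qhm_lr = 0.1 / (1.0 - beta)     # scale the old learning rate 0.1 by 1 / (1 - beta) -> 1.0
>>> optimizer = qhoptim.pyt.QHM(
...     model.parameters(), lr=qhm_lr, momentum=beta, nu=1.0)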
classmethod from_accsgd(delta, kappa, xi, eps=0.7)

Calculates the QHM hyperparameters required to recover the AccSGD optimizer (Kidambi et al., 2018).

Parameters

Returns
Three-element dict containing lr, momentum, and nu to use in QHM.
Example
>>> optimizer = qhoptim.pyt.QHM(
...     model.parameters(),
...     weight_decay=1e-4,
...     **qhoptim.pyt.QHM.from_accsgd(
...         delta=0.1, kappa=1000.0, xi=10.0))
classmethod from_pid(k_p, k_i, k_d)

Calculates the QHM hyperparameters required to recover a PID optimizer as described in Recht (2018).

Parameters

Returns
Three-element dict containing lr, momentum, and nu to use in QHM.
Example
>>> optimizer = qhoptim.pyt.QHM(
...     model.parameters(),
...     weight_decay=1e-4,
...     **qhoptim.pyt.QHM.from_pid(
...         k_p=-0.1, k_i=1.0, k_d=3.0))
classmethod from_robust_momentum(l, kappa, rho=None)

Calculates the QHM hyperparameters required to recover the Robust Momentum (Cyrus et al., 2018) or Triple Momentum (Scoy et al., 2018) optimizers.

Parameters

Returns
Three-element dict containing lr, momentum, and nu to use in QHM.
Example
>>> optimizer = qhoptim.pyt.QHM(
...     model.parameters(),
...     weight_decay=1e-4,
...     **qhoptim.pyt.QHM.from_robust_momentum(
...         l=5.0, kappa=15.0))
classmethod from_synthesized_nesterov(alpha, beta1, beta2)

Calculates the QHM hyperparameters required to recover the synthesized Nesterov optimizer (Section 6 of Lessard et al. (2016)).

Parameters

Returns
Three-element dict containing lr, momentum, and nu to use in QHM.
Example
>>> optimizer = qhoptim.pyt.QHM(
...     model.parameters(),
...     weight_decay=1e-4,
...     **qhoptim.pyt.QHM.from_synthesized_nesterov(
...         alpha=0.1, beta1=0.9, beta2=0.6))
classmethod from_two_state_optimizer(h, k, l, m, q, z)

Calculates the QHM hyperparameters required to recover the following optimizer (named “TSO” in Ma and Yarats (2019)):

\[\begin{split}\begin{align*} a_{t + 1} &\leftarrow h \cdot a_t + k \cdot \theta_t + l \cdot \nabla_t \\ \theta_{t + 1} &\leftarrow m \cdot a_t + q \cdot \theta_t + z \cdot \nabla_t \end{align*}\end{split}\]

Here, \(a_t\) and \(\theta_t\) are the two states and \(\nabla_t\) is the gradient with respect to \(\theta_t\). A plain-Python sketch of this recurrence follows the example below.

Be careful that your coefficients satisfy the regularity conditions from the reference.

Parameters

Returns
Three-element dict containing lr, momentum, and nu to use in QHM.
Example
>>> optimizer = qhoptim.pyt.QHM(
...     model.parameters(),
...     weight_decay=1e-4,
...     **qhoptim.pyt.QHM.from_two_state_optimizer(
...         h=0.9, k=0.0, l=0.1, m=-0.09, q=1.0, z=-0.01))
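As mentioned above, the TSO recurrence can be written out directly. The following is a minimal scalar sketch of those two equations, not library code; the argument names simply mirror the from_two_state_optimizer signature.

def tso_step_sketch(a, theta, grad, h, k, l, m, q, z):
    """One step of the two-state recurrence shown above (scalar version)."""
    # Both updates read the *old* a_t and theta_t, as in the equations.
    a_next = h * a + k * theta + l * grad
    theta_next = m * a + q * theta + z * grad
    return a_next, theta_next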
QHAdam API reference
class qhoptim.pyt.QHAdam(params, lr=0.001, betas=(0.9, 0.999), nus=(1.0, 1.0), weight_decay=0.0, decouple_weight_decay=False, eps=1e-08)

Implements the QHAdam optimization algorithm (Ma and Yarats, 2019).

Note that the NAdam optimizer is accessible via a specific parameterization of QHAdam. See from_nadam() for details.

Parameters
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional) – learning rate (\(\alpha\) from the paper) (default: 1e-3)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of the gradient and its square (default: (0.9, 0.999))
nus (Tuple[float, float], optional) – immediate discount factors used to estimate the gradient and its square (default: (1.0, 1.0))
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional) – weight decay (default: 0.0)
decouple_weight_decay (bool, optional) – whether to decouple the weight decay from the gradient-based optimization step (default: False)
Example
>>> optimizer = qhoptim.pyt.QHAdam(
...     model.parameters(),
...     lr=3e-4, nus=(0.8, 1.0), betas=(0.99, 0.999))
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
classmethod from_nadam(lr=0.001, betas=(0.9, 0.999))

Calculates the QHAdam hyperparameters required to recover the NAdam optimizer (Dozat, 2016).

This is not an identical recovery of the formulation in the paper, due to subtle differences in the application of the bias correction in the first moment estimator. However, in practice, this difference is almost certainly irrelevant.

Parameters

Returns
Three-element dict containing lr, betas, and nus to use in QHAdam.
Example
>>> optimizer = qhoptim.pyt.QHAdam(
...     model.parameters(),
...     weight_decay=1e-4,
...     **qhoptim.pyt.QHAdam.from_nadam(
...         lr=1e-3, betas=(0.9, 0.999)))
qhoptim.pyt.QHAdamW(params, *args, **kwargs)

Constructs the decoupled weight decay variant of the QHAdam optimization algorithm (Ma and Yarats, 2019), as proposed by Loshchilov and Hutter (2017).

Shares all arguments of the QHAdam constructor – equivalent to constructing QHAdam with decouple_weight_decay=True.
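Concretely, the equivalence described above means that the following two constructions configure the same optimizer; the hyperparameter values and the model are placeholders for illustration:

>>> from qhoptim.pyt import QHAdam, QHAdamW
>>> opt_a = QHAdamW(
...     model.parameters(), lr=1e-3, nus=(0.7, 1.0),
...     betas=(0.995, 0.999), weight_decay=1e-2)
>>> opt_b = QHAdam(
...     model.parameters(), lr=1e-3, nus=(0.7, 1.0),
...     betas=(0.995, 0.999), weight_decay=1e-2,
...     decouple_weight_decay=True)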