Source code for neuraltrain.optimizers.base
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Pydantic configurations for optimizers."""

import typing as tp

import exca
import pydantic
import torch
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer

from neuraltrain.utils import all_subclasses

# Name -> class registries for all torch optimizers and LR schedulers
# ("NewCls" is skipped: a dynamically created torch-internal subclass).
TORCH_OPTIMIZER_NAMES = {
    cls.__name__: cls for cls in all_subclasses(Optimizer) if cls.__name__ != "NewCls"
}
TORCH_LR_SCHEDULER_NAMES = {cls.__name__: cls for cls in all_subclasses(LRScheduler)}
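
# For example (exact contents depend on the installed torch version):
#
#     TORCH_OPTIMIZER_NAMES["Adam"]         # -> torch.optim.Adam
#     TORCH_LR_SCHEDULER_NAMES["StepLR"]    # -> torch.optim.lr_scheduler.StepLR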


class BaseOptimizer(exca.helpers.DiscriminatedModel, discriminator_key="name"):
    """Base class for optimizer configurations."""

    def build(self, params: tp.Iterable[torch.Tensor]) -> Optimizer:
        raise NotImplementedError
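
# Illustrative subclass sketch (not part of this module): exca's
# DiscriminatedModel dispatches on the "name" key, so a payload such as
# {"name": "MySGD", "lr": 0.1} would resolve to a hypothetical config like:
#
#     class MySGD(BaseOptimizer):
#         lr: float = 0.1
#
#         def build(self, params: tp.Iterable[torch.Tensor]) -> Optimizer:
#             return torch.optim.SGD(params, lr=self.lr)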


# Base class for torch optimizer configs (using the kwargs pattern)
class BaseTorchOptimizer(BaseOptimizer):
    """Base class for torch optimizer configurations."""

    _OPTIMIZER_CLASS: tp.ClassVar[type[Optimizer]]
    lr: float
    kwargs: dict[str, tp.Any] = {}

    def build(self, params: tp.Iterable[torch.Tensor]) -> Optimizer:
        if "lr" in self.kwargs:
            raise ValueError(
                "lr should be defined as a base parameter instead of within kwargs."
            )
        exca.helpers.validate_kwargs(
            self._OPTIMIZER_CLASS, self.kwargs | {"params": None}
        )
        return self._OPTIMIZER_CLASS(params, lr=self.lr, **self.kwargs)  # type: ignore[call-arg]


# Generate config classes for all torch optimizers
for optimizer_name, optimizer_class in TORCH_OPTIMIZER_NAMES.items():
    optimizer_config_cls: type[BaseTorchOptimizer] = pydantic.create_model(  # type: ignore[assignment]
        optimizer_name,
        __base__=BaseTorchOptimizer,
    )
    optimizer_config_cls._OPTIMIZER_CLASS = optimizer_class  # type: ignore[attr-defined]
    globals()[optimizer_name] = optimizer_config_cls
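
# Usage sketch (illustrative): the loop above exposes one config class per
# discovered optimizer, e.g. ``Adam`` when torch.optim.Adam is available:
#
#     config = Adam(lr=1e-3, kwargs={"weight_decay": 0.01})
#     optimizer = config.build(model.parameters())
#
# Note that passing "lr" inside kwargs would raise the ValueError in build().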


class BaseLRScheduler(exca.helpers.DiscriminatedModel, discriminator_key="name"):
    """Base class for learning rate scheduler configurations."""

    def build(self, optimizer: Optimizer, **build_kwargs: tp.Any) -> LRScheduler:
        raise NotImplementedError


# Base class for torch LR scheduler configs (using the kwargs pattern)
class BaseTorchLRScheduler(BaseLRScheduler):
    """Base class for torch LR scheduler configurations."""

    _SCHEDULER_CLASS: tp.ClassVar[type[LRScheduler]]
    kwargs: dict[str, tp.Any] = {}

    def build(self, optimizer: Optimizer, **build_kwargs: tp.Any) -> LRScheduler:
        exca.helpers.validate_kwargs(
            self._SCHEDULER_CLASS, self.kwargs | {"optimizer": None}
        )
        return self._SCHEDULER_CLASS(optimizer, **(self.kwargs | build_kwargs))


# Generate config classes for all torch LR schedulers
for scheduler_name, scheduler_class in TORCH_LR_SCHEDULER_NAMES.items():
    scheduler_config_cls: type[BaseTorchLRScheduler] = pydantic.create_model(  # type: ignore[assignment]
        scheduler_name,
        __base__=BaseTorchLRScheduler,
    )
    scheduler_config_cls._SCHEDULER_CLASS = scheduler_class  # type: ignore[attr-defined]
    globals()[scheduler_name] = scheduler_config_cls
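
# Usage sketch (illustrative): e.g. a StepLR config, assuming
# torch.optim.lr_scheduler.StepLR was discovered above:
#
#     sched_config = StepLR(kwargs={"step_size": 10, "gamma": 0.5})
#     scheduler = sched_config.build(optimizer)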


class LightningOptimizer(BaseOptimizer):
    """Pydantic configuration for a Lightning optimizer."""

    optimizer: BaseOptimizer
    scheduler: BaseLRScheduler | None = None
    interval: tp.Literal["step", "epoch"] = "step"

    def build(  # type: ignore[override]
        self,
        params: tp.Iterable[torch.Tensor],
        **scheduler_build_kwargs: tp.Any,
    ) -> dict[str, tp.Any]:
        out: dict[str, tp.Any] = {"optimizer": self.optimizer.build(params)}
        if self.scheduler is not None:
            scheduler = self.scheduler.build(out["optimizer"], **scheduler_build_kwargs)
            out["lr_scheduler"] = {"scheduler": scheduler, "interval": self.interval}
        return out
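
# Usage sketch (illustrative, assumes the ``lightning`` package): the dict
# returned by build() matches what Lightning's ``configure_optimizers``
# expects; ``Adam`` and ``CosineAnnealingLR`` are the generated configs above:
#
#     cfg = LightningOptimizer(
#         optimizer=Adam(lr=3e-4),
#         scheduler=CosineAnnealingLR(kwargs={"T_max": 100}),
#         interval="step",
#     )
#
#     class MyModule(lightning.LightningModule):
#         def configure_optimizers(self):
#             return cfg.build(self.parameters())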