Source code for neuraltrain.losses.base
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Pydantic configurations for loss functions."""
import typing as tp
from inspect import isclass
import exca
import pydantic
from torch import nn
from torch.nn.modules.loss import _Loss
from neuraltrain.utils import all_subclasses, convert_to_pydantic
from . import losses
# All classes found in the namespace of the sibling ``losses`` module that
# derive from ``nn.Module``; these are the candidates for generated configs.
custom_losses = [
    candidate
    for candidate in vars(losses).values()
    if isclass(candidate) and issubclass(candidate, nn.Module)
]
# Class-name -> class lookup for everything ``all_subclasses`` reports for
# torch's private ``_Loss`` base. If two classes share a name, the one
# yielded last wins.
TORCHLOSS_NAMES = {cls.__name__: cls for cls in all_subclasses(_Loss)}
[docs]
class BaseLoss(exca.helpers.DiscriminatedModel, discriminator_key="name"):
    """Common parent of every loss configuration in this module.

    The class is declared as an ``exca`` ``DiscriminatedModel`` with
    ``"name"`` as its discriminator key (presumably the field used to pick
    the concrete config subclass -- confirm against the exca documentation).
    """

    def build(self, **kwargs: tp.Any) -> nn.Module:
        """Instantiate the loss module described by this configuration.

        Concrete subclasses must override this method.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError()
# Derive one config class per custom loss via convert_to_pydantic and publish
# it in this module under the same name as the loss class it wraps.
# "MultiLoss" is excluded from generation; this module declares a
# hand-written ``MultiLoss`` config class instead.
for loss_class in custom_losses:
    if loss_class.__name__ != "MultiLoss":
        config_cls = convert_to_pydantic(
            loss_class,
            loss_class.__name__,
            parent_class=BaseLoss,
        )
        globals()[loss_class.__name__] = config_cls
# Base class for torch loss configs (using kwargs pattern)
[docs]
class BaseTorchLoss(BaseLoss):
    """Configuration wrapper around a torch loss class.

    The wrapped class is stored in ``_LOSS_CLASS`` and its constructor
    arguments are carried in the free-form ``kwargs`` field.
    """

    # Loss class instantiated by ``build``; annotated only, so each concrete
    # subclass has to assign it.
    _LOSS_CLASS: tp.ClassVar[type[nn.Module]]
    # Constructor arguments stored in the config.
    kwargs: dict[str, tp.Any] = {}

    def model_post_init(self, log__: tp.Any) -> None:
        """Check the configured kwargs against the wrapped loss class."""
        super().model_post_init(log__)
        # Validates mandatory/extra arguments and basic types (str/int/float).
        exca.helpers.validate_kwargs(self._LOSS_CLASS, self.kwargs)

    def build(self, **kwargs: tp.Any) -> nn.Module:
        """Instantiate the wrapped loss class.

        Args:
            **kwargs: extra constructor arguments supplied at build time;
                none of them may also be present in ``self.kwargs``.

        Returns:
            An instance of ``_LOSS_CLASS`` created from the config kwargs
            followed by the build-time kwargs.

        Raises:
            ValueError: if a key appears in both ``self.kwargs`` and
                ``kwargs``.
        """
        overlap = set(self.kwargs) & set(kwargs)
        if overlap:
            raise ValueError(
                f"Build kwargs overlap with config kwargs for keys: {overlap}."
            )
        # The two mappings are disjoint at this point, so unpacking both is
        # the same as merging them.
        return self._LOSS_CLASS(**self.kwargs, **kwargs)
# Create one BaseTorchLoss subclass per entry of TORCHLOSS_NAMES, bind the
# matching torch class to it, and publish it in this module under the torch
# class name.
for loss_name, loss_class in TORCHLOSS_NAMES.items():
    torch_loss_cls: type[BaseTorchLoss] = pydantic.create_model(  # type: ignore[assignment]
        loss_name, __base__=BaseTorchLoss
    )
    # ``_LOSS_CLASS`` is what BaseTorchLoss validates against and builds.
    setattr(torch_loss_cls, "_LOSS_CLASS", loss_class)
    globals()[loss_name] = torch_loss_cls
class MultiLoss(BaseLoss):
    """Configuration for either one loss or a named collection of losses."""

    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    # A single loss config, or a mapping from a name to a loss config.
    losses: BaseLoss | dict[str, BaseLoss]
    # Optional per-name weights; only consulted when ``losses`` is a mapping.
    weights: dict[str, float] | None = None

    def model_post_init(self, log__: tp.Any) -> None:
        """Ensure ``weights`` and ``losses`` name exactly the same entries.

        The check runs only when both fields are dictionaries.

        Raises:
            ValueError: if a key is present in only one of the two.
        """
        super().model_post_init(log__)
        if not (isinstance(self.losses, dict) and isinstance(self.weights, dict)):
            return
        diff = set(self.losses).symmetric_difference(self.weights)
        if diff:
            raise ValueError(f"weights and losses key differ: {diff}")

    def build(self, **kwargs: tp.Any) -> nn.Module:
        """Build the configured loss module(s).

        Args:
            **kwargs: accepted for signature compatibility with the parent
                class; not forwarded to any sub-loss.

        Returns:
            For a mapping, a ``losses.MultiLoss`` wrapping the built
            sub-losses and ``weights``; for a single config, the module that
            config builds (``weights`` is not used in that case).
        """
        if isinstance(self.losses, dict):
            built_losses = {}
            for name, loss in self.losses.items():
                built_losses[name] = loss.build()
            return losses.MultiLoss(built_losses, weights=self.weights)
        return self.losses.build()