# Source code for neuraltrain.metrics.base

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Pydantic configurations for metrics."""

import typing as tp
from inspect import isclass

import exca
import pydantic
from torchmetrics import Metric

from neuraltrain.metrics import metrics
from neuraltrain.utils import all_subclasses, convert_to_pydantic

# Metric classes exposed by ``neuraltrain.metrics.metrics`` (the custom metrics).
# ``Metric`` itself is excluded explicitly: if that module imports ``Metric`` (it
# must be in scope there to be subclassed), ``issubclass(Metric, Metric)`` is True
# and the abstract base would otherwise be treated as a custom metric. A config
# class would then be generated for it and assigned to the module-level name
# ``Metric``, shadowing the torchmetrics import in this module.
# NOTE(review): any other torchmetrics class imported into that module is still
# collected here as "custom" — confirm that is intended.
custom_metrics = [
    obj
    for obj in metrics.__dict__.values()
    if isclass(obj) and issubclass(obj, Metric) and obj is not Metric
]


# Subclasses of ``Metric`` returned by ``all_subclasses`` (presumably only those
# whose modules are already imported at this point — verify), keyed by class name.
# Custom metrics are excluded because they receive dedicated configs instead, and
# the abstract ``Metric`` base is excluded in case ``all_subclasses`` includes it.
# If two classes share a ``__name__``, the one visited last wins.
TORCHMETRICS_NAMES = {
    metric_class.__name__: metric_class
    for metric_class in all_subclasses(Metric)
    if metric_class is not Metric and metric_class not in custom_metrics
}


class BaseMetric(exca.helpers.DiscriminatedModel, discriminator_key="name"):
    """Common parent of every metric configuration.

    Concrete configurations are told apart through the ``name`` discriminator
    key and are expected to provide their own :meth:`build`.
    """

    # Required field; presumably the key under which the metric is logged.
    log_name: str

    def build(self) -> Metric:
        """Create the torchmetrics ``Metric`` described by this configuration.

        Raises
        ------
        NotImplementedError
            Always, on this base class; subclasses override it.
        """
        raise NotImplementedError
# Generate config classes using convert_to_pydantic for custom metrics for metric_class in custom_metrics: config_cls = convert_to_pydantic( metric_class, metric_class.__name__, parent_class=BaseMetric, exclude_from_build=["log_name"], ) globals()[metric_class.__name__] = config_cls # Base class for torchmetrics configs (using kwargs pattern)
class BaseTorchMetric(BaseMetric):
    """Configuration wrapper around a stock torchmetrics class.

    Each concrete subclass binds ``_METRIC_CLASS``; the arguments for that
    class's constructor are carried verbatim in ``kwargs``.
    """

    # torchmetrics class instantiated by ``build``; bound on each concrete subclass.
    _METRIC_CLASS: tp.ClassVar[type[Metric]]
    # Keyword arguments forwarded as-is to ``_METRIC_CLASS``.
    kwargs: dict[str, tp.Any] = {}

    def model_post_init(self, log__: tp.Any) -> None:
        """Run the parent post-init hook, then check ``kwargs`` against the metric."""
        super().model_post_init(log__)
        # Checks mandatory/extra arguments and basic (str/int/float) types.
        exca.helpers.validate_kwargs(self._METRIC_CLASS, self.kwargs)

    def build(self) -> Metric:
        """Instantiate the bound torchmetrics class with the stored ``kwargs``."""
        metric_type = self._METRIC_CLASS
        return metric_type(**self.kwargs)
# Generate config classes for all torchmetrics for metric_name, metric_class in TORCHMETRICS_NAMES.items(): torch_config_cls: type[BaseTorchMetric] = pydantic.create_model( # type: ignore[assignment] metric_name, __base__=BaseTorchMetric, ) torch_config_cls._METRIC_CLASS = metric_class # type: ignore[attr-defined] globals()[metric_name] = torch_config_cls