Source code for neuraltrain.models.constant_predictor
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import torch
from torch import nn
from .base import BaseModelConfig
[docs]
class ConstantPredictor(BaseModelConfig):
"""
Constant predictor that predicts the most frequent class or the mean of the targets.
Parameters
----------
mode: tp.Literal["most_frequent", "most_frequent_multilabel", "mean", "auto"]
Mode to use for obtaining the constant from the targets. Options:
- "most_frequent": Predicts the most frequent class (for single-label classification).
- "most_frequent_multilabel": Predicts the most frequent value for each class independently
(for multilabel classification where multiple classes can be active simultaneously).
- "mean": Predicts the mean of the targets (for regression).
- "auto": Automatically determines the mode based on the dtype and shape of the targets.
"""
mode: tp.Literal["most_frequent", "most_frequent_multilabel", "mean", "auto"] = "auto"
def build(self, y_train: torch.Tensor) -> "ConstantPredictorModel": # type: ignore[override]
if self.mode == "auto":
if y_train.dtype in (torch.int, torch.int64, torch.long):
mode = "most_frequent"
if y_train.ndim == 2:
n_classes_per_example = (y_train > 0).sum(dim=1)
if (n_classes_per_example == 0).any() or (
n_classes_per_example > 1
).any():
mode = "most_frequent_multilabel"
elif torch.is_floating_point(y_train):
mode = "mean"
else:
raise ValueError(f"Unsupported dtype: {y_train.dtype}")
else:
mode = self.mode
if mode == "most_frequent":
if y_train.ndim == 1:
most_frequent_ind, _ = torch.mode(y_train)
n_classes = int(y_train.max().item()) + 1
elif y_train.ndim == 2:
most_frequent_ind = y_train.sum(dim=0).argmax()
n_classes = y_train.shape[1]
else:
raise NotImplementedError()
out = torch.nn.functional.one_hot(most_frequent_ind, num_classes=n_classes)
elif mode == "most_frequent_multilabel":
out = (y_train > 0).int().mode(dim=0)[0]
elif mode == "mean":
out = y_train.mean(dim=0)
else:
raise ValueError(f"Unsupported dtype: {y_train.dtype}")
return ConstantPredictorModel(out=out.float())
[docs]
class ConstantPredictorModel(nn.Module):
out: torch.Tensor
def __init__(self, out: torch.Tensor):
super().__init__()
self.register_buffer("out", out)
[docs]
def forward(self, X: torch.Tensor) -> torch.Tensor:
return self.out.repeat(X.shape[0], 1)