# Source code for neuraltrain.models.dummy_predictor
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import torch
from torch import nn
from .base import BaseModelConfig
class DummyPredictor(BaseModelConfig):
    """
    Dummy predictor that makes predictions using simple rules based on the
    target distribution, analogous to ``sklearn.dummy.DummyClassifier``.

    Parameters
    ----------
    mode: tp.Literal[
        "most_frequent",
        "most_frequent_multilabel",
        "stratified_multilabel",
        "mean",
        "auto",
    ]
        Strategy used to derive predictions from the training targets.

        - ``"most_frequent"``: predict the most frequent class (single-label
          classification, constant output).
        - ``"most_frequent_multilabel"``: predict the most frequent binary
          value per class independently (multilabel classification, constant
          output per class).
        - ``"stratified_multilabel"``: sample each class label independently
          from a Bernoulli distribution with probability equal to the class
          prevalence in the training set (multilabel classification,
          stochastic output). Produces macro-F1 scores that reflect the
          class prior rather than collapsing to 0 on rare-class tasks.
        - ``"mean"``: predict the mean of the targets (regression).
        - ``"auto"``: automatically pick a mode based on the dtype and shape
          of the targets. Multilabel integer targets resolve to
          ``"stratified_multilabel"``.
    random_state: int | None
        Seed used to initialize the ``torch.Generator`` that drives the
        Bernoulli sampling in ``stratified_multilabel`` mode. ``None``
        (default) falls back to the global RNG state (controlled e.g. by
        Lightning's ``seed_everything``). Ignored by the other modes.
    """

    mode: tp.Literal[
        "most_frequent",
        "most_frequent_multilabel",
        "stratified_multilabel",
        "mean",
        "auto",
    ] = "auto"
    random_state: int | None = None

    def build(self, y_train: torch.Tensor) -> "DummyPredictorModel":  # type: ignore[override]
        """Build the evaluation-only model from the training targets.

        Parameters
        ----------
        y_train: torch.Tensor
            Training targets. 1D class indices or 2D (multilabel /
            multi-output) targets depending on the mode.

        Returns
        -------
        DummyPredictorModel
            Module producing the dummy predictions at forward time.

        Raises
        ------
        ValueError
            If the mode is unknown, if ``"auto"`` cannot resolve the dtype,
            or if ``stratified_multilabel`` receives non-2D targets.
        NotImplementedError
            If ``most_frequent`` receives targets with more than 2 dims.
        """
        if self.mode == "auto":
            # torch.long is an alias of torch.int64, so listing both is
            # redundant; torch.int (int32) covers the other common case.
            if y_train.dtype in (torch.int, torch.int64):
                mode = "most_frequent"
                if y_train.ndim == 2:
                    # 2D integer targets with rows that have zero or more
                    # than one active class are genuinely multilabel.
                    n_classes_per_example = (y_train > 0).sum(dim=1)
                    if (n_classes_per_example == 0).any() or (
                        n_classes_per_example > 1
                    ).any():
                        mode = "stratified_multilabel"
            elif torch.is_floating_point(y_train):
                mode = "mean"
            else:
                raise ValueError(f"Unsupported dtype: {y_train.dtype}")
        else:
            mode = self.mode
        if mode == "most_frequent":
            if y_train.ndim == 1:
                # Class indices: mode over the vector, classes 0..max.
                most_frequent_ind, _ = torch.mode(y_train)
                n_classes = int(y_train.max().item()) + 1
            elif y_train.ndim == 2:
                # One-hot-ish targets: pick the column with the largest sum.
                most_frequent_ind = y_train.sum(dim=0).argmax()
                n_classes = y_train.shape[1]
            else:
                raise NotImplementedError()
            out = torch.nn.functional.one_hot(most_frequent_ind, num_classes=n_classes)
            return DummyPredictorModel(out=out.float())
        if mode == "most_frequent_multilabel":
            # Per-class majority vote over binarized targets.
            out = (y_train > 0).int().mode(dim=0)[0]
            return DummyPredictorModel(out=out.float())
        if mode == "stratified_multilabel":
            if y_train.ndim != 2:
                raise ValueError(
                    f"stratified_multilabel requires 2D targets, got ndim={y_train.ndim}."
                )
            # Per-class prevalence drives Bernoulli sampling at forward time.
            probs = (y_train > 0).float().mean(dim=0)
            return DummyPredictorModel(probs=probs, random_state=self.random_state)
        if mode == "mean":
            out = y_train.mean(dim=0)
            return DummyPredictorModel(out=out.float())
        raise ValueError(f"Unsupported mode: {mode}")
[docs]
class DummyPredictorModel(nn.Module):
"""Evaluation-only module that implements the dummy prediction strategies.
Constructed by :meth:`DummyPredictor.build`; not intended to be built
directly by users. Depending on the mode chosen at build time, either
returns a constant tensor tiled across the batch dimension (``out``) or
samples fresh Bernoulli draws per call using a cached ``torch.Generator``
(``probs`` + ``random_state``).
"""
out: torch.Tensor
probs: torch.Tensor
def __init__(
self,
out: torch.Tensor | None = None,
probs: torch.Tensor | None = None,
random_state: int | None = None,
) -> None:
super().__init__()
if (out is None) == (probs is None):
raise ValueError(
"Exactly one of `out` or `probs` must be provided to DummyPredictorModel."
)
self._stratified = probs is not None
if out is not None:
self.register_buffer("out", out)
if probs is not None:
self.register_buffer("probs", probs)
self._random_state = random_state
# Generator objects are device-specific and cannot be registered as
# buffers; cache them lazily per device on first forward call.
self._generators: dict[torch.device, torch.Generator] = {}
def _get_generator(self, device: torch.device) -> torch.Generator | None:
if self._random_state is None:
return None
gen = self._generators.get(device)
if gen is None:
gen = torch.Generator(device=device)
gen.manual_seed(self._random_state)
self._generators[device] = gen
return gen
[docs]
def forward(self, X: torch.Tensor) -> torch.Tensor:
if not self._stratified:
return self.out.repeat(X.shape[0], 1)
probs = self.probs.expand(X.shape[0], -1)
generator = self._get_generator(probs.device)
return torch.bernoulli(probs, generator=generator)