# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import torch
from torch import nn
from .base import BaseModelConfig
[docs]
class DummyPredictor(BaseModelConfig):
"""
Dummy predictor that makes predictions using simple rules based on the
target distribution, analogous to ``sklearn.dummy.DummyClassifier``.
Parameters
----------
mode: tp.Literal[
"most_frequent",
"most_frequent_multilabel",
"stratified_multilabel",
"mean",
"auto",
]
Strategy used to derive predictions from the training targets.
- ``"most_frequent"``: predict the most frequent class (single-label
classification, constant output).
- ``"most_frequent_multilabel"``: predict the most frequent binary
value per class independently (multilabel classification, constant
output per class).
- ``"stratified_multilabel"``: sample each class label independently
from a Bernoulli distribution with probability equal to the class
prevalence in the training set (multilabel classification,
stochastic output). Produces macro-F1 scores that reflect the
class prior rather than collapsing to 0 on rare-class tasks.
- ``"mean"``: predict the mean of the targets (regression).
- ``"auto"``: automatically pick a mode based on the dtype and shape
of the targets. Multilabel integer targets resolve to
``"stratified_multilabel"``.
random_state: int | None
Seed used to initialize the ``torch.Generator`` that drives the
Bernoulli sampling in ``stratified_multilabel`` mode. ``None``
(default) falls back to the global RNG state (controlled e.g. by
Lightning's ``seed_everything``). Ignored by the other modes.
"""
mode: tp.Literal[
"most_frequent",
"most_frequent_multilabel",
"stratified_multilabel",
"mean",
"auto",
] = "auto"
random_state: int | None = None
[docs]
def build( # type: ignore[override]
self,
y_train: torch.Tensor,
blank_idx: int | None = None,
n_classes: int | None = None,
) -> "DummyPredictorModel | DummyCtcSequenceModel":
# CTC sequence tasks: ``y_train`` is ``(n_samples, max_length)`` of
# blank-padded label ids rather than a class / multilabel row. When
# the caller signals a CTC objective (``blank_idx`` set), ignore the
# classification ``mode`` and emit a constant most-frequent-character
# sequence (see ``DummyCtcSequenceModel``).
if blank_idx is not None:
return self._build_ctc_sequence(y_train, blank_idx, n_classes)
if self.mode == "auto":
if y_train.dtype in (torch.int, torch.int64, torch.long):
mode = "most_frequent"
if y_train.ndim == 2:
n_classes_per_example = (y_train > 0).sum(dim=1)
if (n_classes_per_example == 0).any() or (
n_classes_per_example > 1
).any():
mode = "stratified_multilabel"
elif torch.is_floating_point(y_train):
mode = "mean"
else:
raise ValueError(f"Unsupported dtype: {y_train.dtype}")
else:
mode = self.mode
if mode == "most_frequent":
if y_train.ndim == 1:
most_frequent_ind, _ = torch.mode(y_train)
n_classes = int(y_train.max().item()) + 1
elif y_train.ndim == 2:
most_frequent_ind = y_train.sum(dim=0).argmax()
n_classes = y_train.shape[1]
else:
raise NotImplementedError()
out = torch.nn.functional.one_hot(most_frequent_ind, num_classes=n_classes)
return DummyPredictorModel(out=out.float())
if mode == "most_frequent_multilabel":
out = (y_train > 0).int().mode(dim=0)[0]
return DummyPredictorModel(out=out.float())
if mode == "stratified_multilabel":
if y_train.ndim != 2:
raise ValueError(
f"stratified_multilabel requires 2D targets, got ndim={y_train.ndim}."
)
probs = (y_train > 0).float().mean(dim=0)
return DummyPredictorModel(probs=probs, random_state=self.random_state)
if mode == "mean":
out = y_train.mean(dim=0)
return DummyPredictorModel(out=out.float())
raise ValueError(f"Unsupported mode: {mode}")
@staticmethod
def _build_ctc_sequence(
y_train: torch.Tensor, blank_idx: int, n_classes: int | None
) -> "DummyCtcSequenceModel":
"""Most-frequent-character CTC baseline (analogue of ``most_frequent``).
Given blank-padded label sequences ``y_train`` of shape
``(n_samples, max_length)``, predict the most frequent non-blank
character repeated ``L_median`` times, where ``L_median`` is the
median number of real (non-blank) tokens per sequence. The median
minimises the L1 length penalty under edit-distance error rates.
"""
y = y_train.long()
if y.ndim != 2:
raise ValueError(
f"CTC dummy expects 2D (n_samples, max_length) targets; "
f"got ndim={y.ndim}."
)
if n_classes is None:
n_classes = int(y.max().item()) + 1
non_blank_mask = y != blank_idx
median_length = int(non_blank_mask.sum(dim=1).float().median().item())
non_blank = y[non_blank_mask]
if non_blank.numel() == 0 or median_length == 0:
# No real keystrokes -> predict the empty string.
modal_char = blank_idx
median_length = 0
else:
modal_char = int(torch.mode(non_blank).values.item())
return DummyCtcSequenceModel(
modal_char=modal_char,
length=median_length,
blank_idx=blank_idx,
n_classes=n_classes,
)
[docs]
class DummyPredictorModel(nn.Module):
"""Evaluation-only module that implements the dummy prediction strategies.
Constructed by :meth:`DummyPredictor.build`; not intended to be built
directly by users. Depending on the mode chosen at build time, either
returns a constant tensor tiled across the batch dimension (``out``) or
samples fresh Bernoulli draws per call using a cached ``torch.Generator``
(``probs`` + ``random_state``).
"""
out: torch.Tensor
probs: torch.Tensor
def __init__(
self,
out: torch.Tensor | None = None,
probs: torch.Tensor | None = None,
random_state: int | None = None,
) -> None:
super().__init__()
if (out is None) == (probs is None):
raise ValueError(
"Exactly one of `out` or `probs` must be provided to DummyPredictorModel."
)
self._stratified = probs is not None
if out is not None:
self.register_buffer("out", out)
if probs is not None:
self.register_buffer("probs", probs)
self._random_state = random_state
# Generator objects are device-specific and cannot be registered as
# buffers; cache them lazily per device on first forward call.
self._generators: dict[torch.device, torch.Generator] = {}
def _get_generator(self, device: torch.device) -> torch.Generator | None:
if self._random_state is None:
return None
gen = self._generators.get(device)
if gen is None:
gen = torch.Generator(device=device)
gen.manual_seed(self._random_state)
self._generators[device] = gen
return gen
[docs]
def forward(self, X: torch.Tensor) -> torch.Tensor:
if not self._stratified:
return self.out.repeat(X.shape[0], 1)
probs = self.probs.expand(X.shape[0], -1)
generator = self._get_generator(probs.device)
return torch.bernoulli(probs, generator=generator)
class DummyCtcSequenceModel(nn.Module):
"""Constant CTC emitter for the most-frequent-character baseline.
Emits input-independent per-frame scores of shape
``(batch, n_frames, n_classes)`` whose greedy CTC decode (collapse
consecutive repeats + drop the blank, as
:class:`neuraltrain.metrics.CharacterErrorRates` does) is exactly
``modal_char`` repeated ``length`` times. The modal character is
interleaved with the blank class (``[c, blank, c, blank, ..., c]``) so
the repeats survive consecutive-collapse and the decoded length stays
``length`` even when ``modal_char`` would otherwise merge.
The non-target classes are filled with a large negative score so the
emissions double as valid (near one-hot) log-probabilities for
:class:`torch.nn.CTCLoss`; ``argmax`` is unaffected by the magnitude.
"""
emissions: torch.Tensor
def __init__(
self, modal_char: int, length: int, blank_idx: int, n_classes: int
) -> None:
super().__init__()
n_frames = max(2 * length - 1, 1)
emissions = torch.full((n_frames, n_classes), -30.0)
for t in range(n_frames):
cls = modal_char if t % 2 == 0 else blank_idx
emissions[t, cls] = 0.0
self.register_buffer("emissions", emissions)
def forward(self, X: torch.Tensor) -> torch.Tensor:
return self.emissions.to(X.device).unsqueeze(0).expand(X.shape[0], -1, -1)