# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Evaluation metrics."""
# pylint: disable=attribute-defined-outside-init
import typing as tp
from collections import defaultdict
import numpy as np
import pandas as pd
import scipy as sp
import torch
import torchmetrics
import torchmetrics.text # noqa: F401 -- registers ``torchmetrics.text.CharErrorRate``
import torchvision.models as tvmodels
import torchvision.transforms as T
from scipy.stats import binom
from torchmetrics.utilities.data import dim_zero_cat
class OnlinePearsonCorr(torchmetrics.regression.PearsonCorrCoef):
    """Pearson correlation coefficient accumulated batch by batch.

    Parameters
    ----------
    dim : int
        Axis along which the correlation is taken.

        - ``dim=0``: correlate across samples, one coefficient per column.
          Accumulation is delegated to the parent ``PearsonCorrCoef``
          (online Welford accumulation), so batch sizes may vary.
        - ``dim=1``: correlate across features/time inside every sample.
          Per-sample coefficients are computed for each batch and stored,
          then reduced across all stored samples in :meth:`compute`.
    reduction : {"mean", "sum", "none"}, optional
        How the coefficients are reduced. ``"mean"`` and ``"sum"`` ignore
        NaN entries; any other value returns the coefficients unreduced.
        Defaults to ``"mean"``.
    torchmetrics_kwargs : dict or None
        Extra keyword arguments forwarded to the ``torchmetrics.Metric``
        constructor (e.g. ``dist_sync_on_step``).
    """

    def __init__(
        self,
        dim: int,
        reduction: tp.Literal["mean", "sum", "none"] | None = "mean",
        torchmetrics_kwargs: dict[str, tp.Any] | None = None,
    ):
        super().__init__(**(torchmetrics_kwargs or {}))
        self.dim = dim
        self.reduction = reduction
        # The dim=0 states are sized from the first batch, see ``update``.
        self._initialized = False
        self.sample_corrs: list[torch.Tensor]
        if self.dim == 1:
            self.add_state("sample_corrs", default=[], dist_reduce_fx="cat")

    @staticmethod
    def _batch_pearsonr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the Pearson r of every row of ``x`` with the matching row of ``y``.

        Parameters
        ----------
        x, y : torch.Tensor of shape ``(B, T)``

        Returns
        -------
        torch.Tensor of shape ``(B,)``
            NaN where either row has zero variance.
        """
        dx = x - x.mean(dim=1, keepdim=True)
        dy = y - y.mean(dim=1, keepdim=True)
        numerator = (dx * dy).sum(dim=1)
        norm_x = dx.pow(2).sum(dim=1).sqrt()
        norm_y = dy.pow(2).sum(dim=1).sqrt()
        denominator = norm_x * norm_y
        # Constant rows give a zero denominator; report NaN instead of inf.
        return torch.where(denominator > 0, numerator / denominator, float("nan"))

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate one batch of predictions and targets."""
        if self.dim == 1:
            self.sample_corrs.append(self._batch_pearsonr(preds, target))
            return

        # dim=0: the parent class does the accumulation, but its states must
        # first be (re-)registered with one entry per output column.
        if not self._initialized:
            self.num_outputs = preds.shape[1]
            for state_name in (
                "mean_x",
                "mean_y",
                "max_abs_dev_x",
                "max_abs_dev_y",
                "var_x",
                "var_y",
                "corr_xy",
                "n_total",
            ):
                self.add_state(
                    state_name,
                    default=torch.zeros(self.num_outputs, device=self.device),
                    dist_reduce_fx=None,
                )
            self._initialized = True
        super().update(preds, target)

    def compute(self):
        """Return the correlation coefficients reduced according to ``reduction``."""
        if self.dim == 1:
            coefficients = dim_zero_cat(self.sample_corrs)
        else:
            coefficients = super().compute()

        if self.reduction == "mean":
            return torch.nanmean(coefficients)
        if self.reduction == "sum":
            return torch.nansum(coefficients)
        return coefficients  # no reduction

    def reset(self) -> None:
        """Reset the states and force them to be re-sized on the next update."""
        self._initialized = False
        super().reset()
class NormalizedRMSE(torchmetrics.regression.MeanSquaredError):
    """Root-mean-squared error divided by the standard deviation of the targets.

    Official metric of Challenge 2 of the NeurIPS 2025 EEG Foundation Challenge
    (https://eeg2025.github.io/). Extends ``torchmetrics.MeanSquaredError`` and
    inherits its running-sum RMSE accumulation; on top of that, first and
    second moments of the targets are accumulated so that ``std`` is computed
    over all targets seen since the last reset rather than per batch.

    Parameters
    ----------
    num_outputs : int
        Number of outputs. Forwarded to ``MeanSquaredError`` and used to size
        the target-moment buffers.
    """

    higher_is_better: bool = False

    def __init__(self, num_outputs: int = 1) -> None:
        super().__init__(squared=False, num_outputs=num_outputs)
        self.add_state(
            "target_sum",
            default=torch.zeros(num_outputs),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "target_sq_sum",
            default=torch.zeros(num_outputs),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "n_target",
            default=torch.tensor(0),
            dist_reduce_fx="sum",
        )
        self.target_sum: torch.Tensor
        self.target_sq_sum: torch.Tensor
        self.n_target: torch.Tensor

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate the squared error and the target moments of one batch."""
        super().update(preds, target)
        t = target.detach()
        if self.num_outputs == 1:
            # A single output means every element belongs to the same
            # variable, so pool the moments over all elements. Reshaping on
            # the last dimension here would broadcast the (1,)-sized buffers
            # to one entry per column.
            t = t.reshape(-1, 1)
        else:
            if t.ndim == 1:
                t = t.unsqueeze(-1)
            t = t.reshape(-1, t.shape[-1])
        t = t.to(self.target_sum.dtype)
        self.target_sum = self.target_sum + t.sum(dim=0)
        self.target_sq_sum = self.target_sq_sum + (t**2).sum(dim=0)
        self.n_target = self.n_target + t.shape[0]

    def compute(self) -> torch.Tensor:
        """Return ``RMSE / std(target)`` over everything accumulated so far.

        The standard deviation is the population one, ``sqrt(E[t^2] - E[t]^2)``.
        A zero standard deviation is not special-cased, so constant targets
        yield ``inf`` or ``nan``.
        """
        rmse = super().compute()
        # Clamp the count so an un-updated metric does not divide by zero here.
        n = self.n_target.to(self.target_sum.dtype).clamp(min=1)
        mean = self.target_sum / n
        var = (self.target_sq_sum / n) - mean**2
        # Rounding can push the variance slightly below zero.
        std = torch.sqrt(torch.clamp(var, min=0.0))
        return rmse / std
class Rank(torchmetrics.Metric):
    """Rank of predictions based on a retrieval set, using cosine similarity.

    Parameters
    ----------
    reduction : {"mean", "median", "std"}
        How to reduce the example-wise ranks.
    relative : bool
        If True, divide ranks by the retrieval-set size so that values lie
        in [0, 1].
    torchmetrics_kwargs : dict or None
        Extra keyword arguments forwarded to the ``torchmetrics.Metric``
        constructor.
    """

    is_differentiable: bool = False
    higher_is_better: bool = False
    full_state_update: bool = True

    def __init__(
        self,
        reduction: tp.Literal["mean", "median", "std"] = "median",
        relative: bool = False,
        torchmetrics_kwargs: dict[str, tp.Any] | None = None,
    ):
        super().__init__(**(torchmetrics_kwargs or {}))
        self.reduction = reduction
        self.relative = relative
        self.add_state(
            "ranks",
            default=torch.Tensor([]),
            dist_reduce_fx="cat",
        )
        self.ranks: torch.Tensor  # For mypy

    @classmethod
    def _compute_sim(cls, x, y, norm_kind="y", eps=1e-15):
        """Return the ``(B, O)`` matrix of (optionally normalised) dot products.

        Parameters
        ----------
        x :
            Queries, of shape (B, C).
        y :
            Candidates, of shape (O, C).
        norm_kind : {None, "x", "y", "xy"}
            Which L2 norms the dot products are divided by.
        eps :
            Added to the norms to avoid dividing by zero.

        Raises
        ------
        ValueError
            If ``norm_kind`` is not one of the supported values.
        """
        if norm_kind is None:
            # Match x so the einsum below never mixes devices or dtypes.
            eq, inv_norms = "b", torch.ones(x.shape[0], dtype=x.dtype, device=x.device)
        elif norm_kind == "x":
            eq, inv_norms = "b", 1 / (eps + x.norm(dim=1, p=2))
        elif norm_kind == "y":
            eq, inv_norms = "o", 1 / (eps + y.norm(dim=1, p=2))
        elif norm_kind == "xy":
            eq = "bo"
            inv_norms = 1 / (
                eps + torch.outer(x.norm(dim=1, p=2), y.norm(dim=1, p=2))
            )
        else:
            raise ValueError(f"norm must be None, x, y or xy, got {norm_kind}.")
        # Normalize inside einsum to avoid creating a copy of candidates which can be pretty big
        return torch.einsum(f"bc,oc,{eq}->bo", x, y, inv_norms)

    @staticmethod
    def _compute_ranks_from_scores(
        scores: torch.Tensor,
        true_scores: torch.Tensor,
        retrieval_size: int | None,
    ) -> torch.Tensor:
        """Return the zero-based rank of each true score inside its row of ``scores``.

        Ranks obtained with strictly-greater-than and greater-than-or-equal
        comparisons are averaged to account for repeated scores. E.g., the
        zero-based rank of prediction "1" in [0, 1, 1, 1, 2] will be 2
        (instead of 1 or 3).

        Parameters
        ----------
        scores :
            Score matrix of shape (N, M), one row per query.
        true_scores :
            Score of the correct candidate for each query, of shape (N, 1).
        retrieval_size :
            If not None, ranks are divided by this value.
        """
        ranks_gt = (scores > true_scores).sum(dim=1)
        ranks_ge = (scores >= true_scores).sum(dim=1) - 1
        ranks = (ranks_gt + ranks_ge) / 2
        # A negative rank means not even the true score compared >= itself
        # (e.g. it is NaN). Fall back to the middle of the candidate list,
        # i.e. half the number of columns, not half the number of queries.
        ranks[ranks < 0] = scores.shape[1] // 2
        if retrieval_size is not None:
            ranks /= retrieval_size
        return ranks

    def _compute_ranks(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        x_labels: None | list[str] = None,
        y_labels: None | list[str] = None,
    ) -> torch.Tensor:
        """Score ``x`` against ``y`` and rank the true candidate of every query.

        Raises
        ------
        ValueError
            If only one of the label lists is given, if a query label is
            absent from ``y_labels``, or if no labels are given and ``x`` and
            ``y`` differ in their first dimension.
        """
        scores = self._compute_sim(x, y)
        if x_labels is not None and y_labels is not None:
            # Use explicit mapping to match predictions and targets. Build the
            # label -> index table once; ``setdefault`` keeps the first
            # occurrence, like ``list.index`` would.
            first_index: dict[str, int] = {}
            for idx, label in enumerate(y_labels):
                first_index.setdefault(label, idx)
            try:
                matched = [first_index[label] for label in x_labels]
            except KeyError as err:
                # Same exception type as the ``list.index`` lookup it replaces.
                raise ValueError(f"{err.args[0]!r} is not in list") from None
            true_inds = torch.tensor(
                matched,
                dtype=torch.long,
                device=scores.device,
            )[:, None]
            true_scores = torch.take_along_dim(scores, true_inds, dim=1)
        else:
            # Assume 1:1 mapping of predictions and targets
            if x_labels is not None or y_labels is not None:
                raise ValueError(
                    "x_labels and y_labels must both be None or both provided"
                )
            if x.shape[0] != y.shape[0]:
                raise ValueError(
                    f"x and y must have same first dim, got {x.shape=} vs {y.shape=}"
                )
            true_scores = torch.diag(scores)[:, None]
        return self._compute_ranks_from_scores(
            scores, true_scores, len(y) if self.relative else None
        )

    @torch.inference_mode()
    def update(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        x_labels: None | list[str] = None,
        y_labels: None | list[str] = None,
    ) -> None:
        """Update internal list of ranks.

        Parameters
        ----------
        x :
            Tensor of predictions, of shape (N, F).
        y :
            Tensor of retrieval set examples, of shape (M, F).
        x_labels, y_labels :
            If provided, used to match predictions and ground truths that don't have the same
            number of examples. Should have length of N and M, respectively
        """
        ranks = self._compute_ranks(x, y, x_labels, y_labels)
        self.ranks = torch.cat([self.ranks, ranks])  # type: ignore

    def compute(self) -> torch.Tensor:
        """Reduce the accumulated ranks with the configured ``reduction``.

        Raises
        ------
        ValueError
            If ``reduction`` is not "mean", "median" or "std".
        """
        reducers: dict[str, tp.Callable] = {
            "mean": torch.mean,
            "median": torch.median,
            "std": torch.std,
        }
        agg_func = reducers.get(self.reduction)
        if agg_func is None:
            raise ValueError(
                f'Unknown aggregation {self.reduction} for computing metric. Available aggregations are: "mean", "median" or "std".'
            )
        return agg_func(self.ranks)

    def _compute_macro_average(
        self, ranks: torch.Tensor, labels: list[str]
    ) -> dict[str, float]:
        """Compute the aggregated rank for each class.

        The mean is used when ``reduction == "mean"``; every other reduction
        (including "std") uses the median.

        Raises
        ------
        ValueError
            If ``ranks`` and ``labels`` differ in length.
        """
        if len(ranks) != len(labels):
            raise ValueError(f"ranks/labels mismatch: {len(ranks)} vs {len(labels)}")
        agg_func = np.mean if self.reduction == "mean" else np.median
        # Pull the values to the host in one go; NumPy cannot consume lists
        # of 0-d tensors that live on an accelerator.
        if isinstance(ranks, torch.Tensor):
            rank_values = ranks.tolist()
        else:
            rank_values = [float(r) for r in ranks]
        groups: dict[str, list[float]] = defaultdict(list)
        for label, value in zip(labels, rank_values):
            groups[label].append(value)
        return {label: float(agg_func(values)) for label, values in groups.items()}

    @classmethod
    def _compute_topk_scores(
        cls,
        x: torch.Tensor,
        y: torch.Tensor,
        y_labels: list[str] | None = None,
        k: int = 5,
    ) -> tuple[list[list[str]] | None, torch.Tensor, list[list[float]]]:
        """Compute the top-k predictions and scores for each example in x.

        Returns
        -------
        topk_labels :
            Labels of the k best candidates per example, or None when
            ``y_labels`` is not provided.
        topk_inds :
            Indices into ``y`` of the k best candidates, best first.
        topk_scores :
            Similarity scores matching ``topk_inds``, as nested lists.
        """
        scores = cls._compute_sim(x, y)
        topk_inds = torch.argsort(scores, dim=1, descending=True)[:, :k]
        # Gather all selected scores at once instead of one ``.item()`` each.
        topk_scores = torch.take_along_dim(scores, topk_inds, dim=1).tolist()
        topk_labels = None
        if y_labels is not None:
            topk_labels = [
                [y_labels[ind] for ind in inds] for inds in topk_inds.tolist()
            ]
        return topk_labels, topk_inds, topk_scores
class TopkAcc(Rank):
    """Top-k accuracy.

    Parameters
    ----------
    topk :
        K in top-k, i.e. minimal rank to classify a prediction as a success.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = True

    def __init__(self, topk: int = 5):
        super().__init__(relative=False)
        self.topk = topk

    def _compute_macro_average(
        self, ranks: torch.Tensor, labels: list[str]
    ) -> dict[str, float]:
        """Compute the top-k accuracy for each class.

        Raises
        ------
        ValueError
            If ``ranks`` and ``labels`` differ in length.
        """
        if len(ranks) != len(labels):
            raise ValueError(f"ranks/labels mismatch: {len(ranks)} vs {len(labels)}")
        # Pull the values to the host in one go; NumPy cannot consume lists
        # of 0-d tensors that live on an accelerator.
        if isinstance(ranks, torch.Tensor):
            rank_values = ranks.tolist()
        else:
            rank_values = [float(r) for r in ranks]
        hits_per_label: dict[str, list[bool]] = defaultdict(list)
        for label, value in zip(labels, rank_values):
            hits_per_label[label].append(value < self.topk)
        return {label: float(np.mean(hits)) for label, hits in hits_per_label.items()}

    def compute(self) -> torch.Tensor:
        """Return the fraction of accumulated ranks strictly below ``topk``."""
        return (self.ranks < self.topk).float().mean()
class TopkAccFromScores(TopkAcc):
    """Top-k accuracy computed from already available similarity scores.

    Parameters
    ----------
    topk :
        K in top-k, i.e. minimal rank to classify a prediction as a success.
    true_labels :
        Defines where to look for the scores of true pairs. If "first", use the scores in the first
        column of the scores matrix; if "diagonal", use the scores on the diagonal.
    """

    def __init__(
        self, topk: int = 5, true_labels: tp.Literal["first", "diagonal"] = "first"
    ):
        super().__init__(topk)
        self.true_labels = true_labels

    def _compute_ranks(self, scores: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Rank the true pair of every row of ``scores``.

        Raises
        ------
        RuntimeError
            If ``true_labels`` is neither "first" nor "diagonal".
        """
        if self.true_labels == "first":
            true_scores = scores[:, [0]]
        elif self.true_labels == "diagonal":
            # Keep a trailing axis so row i is compared against its own
            # diagonal entry; a flat vector would broadcast along columns.
            true_scores = torch.diag(scores)[:, None]
        else:
            raise RuntimeError(
                f"true_labels must be 'first' or 'diagonal', got {self.true_labels!r}"
            )
        return self._compute_ranks_from_scores(scores, true_scores, None)

    @torch.inference_mode()
    def update(self, scores: torch.Tensor) -> None:  # type: ignore[override]
        """Update internal list of ranks.

        Parameters
        ----------
        scores :
            Score matrix with one row per example and one column per candidate.
        """
        ranks = self._compute_ranks(scores)
        self.ranks = torch.cat([self.ranks, ranks])  # type: ignore
class ImageSimilarity(torchmetrics.Metric):
    """Image similarity metric based on feature extraction from a pretrained network.

    Code adapted from:
    https://github.com/ozcelikfu/brain-diffuser/blob/main/scripts/evaluate_reconstruction.py
    https://github.com/ozcelikfu/brain-diffuser/blob/main/scripts/eval_extract_features.py

    Parameters
    ----------
    model_name : {"inceptionv3", "alexnet", "clip", "efficientnet", "swav"}
        Pretrained network used for feature extraction.
    layer : str or int
        Layer of the network to extract features from. Valid values depend
        on *model_name*.
    torchmetrics_kwargs : dict or None
        Extra keyword arguments forwarded to the ``torchmetrics.Metric``
        constructor.

    Raises
    ------
    ValueError
        If *model_name* is not one of the supported networks.
    """

    def __init__(
        self,
        model_name: tp.Literal[
            "inceptionv3", "alexnet", "clip", "efficientnet", "swav"
        ] = "inceptionv3",
        layer: str | int = "avgpool",
        torchmetrics_kwargs: dict[str, tp.Any] | None = None,
    ):
        super().__init__(**(torchmetrics_kwargs or {}))
        # Filled by the forward hook ``fn`` during ``update``.
        self.feat_list: list[torch.Tensor] = []
        if model_name == "inceptionv3":
            net = tvmodels.inception_v3(pretrained=True)
        elif model_name == "alexnet":
            net = tvmodels.alexnet(pretrained=True)
        elif model_name == "clip":
            import clip

            model, _ = clip.load("ViT-L/14")
            net = model.visual
            net = net.to(torch.float32)
        elif model_name == "efficientnet":
            net = tvmodels.efficientnet_b1(weights=True)
        elif model_name == "swav":
            net = torch.hub.load("facebookresearch/swav:main", "resnet50")
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(
                f"Unknown model_name {model_name!r}. Available models are: "
                "'inceptionv3', 'alexnet', 'clip', 'efficientnet' or 'swav'."
            )
        self.net = net.float().eval()
        self.add_state("pred_features", [], dist_reduce_fx=None)
        self.add_state("true_features", [], dist_reduce_fx=None)
        if model_name == "clip":
            self.normalize = T.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            )
        else:
            self.normalize = T.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )
        if model_name in ["efficientnet", "swav"]:
            self.distance_fn = sp.spatial.distance.correlation
        self.model_name = model_name
        self.layer = layer

    def _hook_target(self, layer: str | int) -> torch.nn.Module | None:
        """Return the sub-module whose output is captured for ``layer``."""
        if self.model_name == "inceptionv3":
            if layer == "avgpool":
                return self.net.avgpool
            if layer == "lastconv":
                return self.net.Mixed_7c
            raise ValueError(
                f"Unknown layer {layer} for InceptionV3. Available layers are: 'avgpool' or 'lastconv'."
            )
        if self.model_name == "alexnet":
            if layer == 2:
                return self.net.features[4]
            if layer == 5:
                return self.net.features[11]
            if layer == 7:
                return self.net.classifier[5]
            raise ValueError(
                f"Unknown layer {layer} for AlexNet. Available layers are: 2, 5 or 7."
            )
        if self.model_name == "clip":
            if layer == 7:
                return self.net.transformer.resblocks[7]
            if layer == 12:
                return self.net.transformer.resblocks[12]
            if layer == "final":
                return self.net
            raise ValueError(
                f"Unknown layer {layer} for CLIP. Available layers are: 7, 12 or 'final'."
            )
        if self.model_name in ("efficientnet", "swav"):
            # ``layer`` is not consulted for these two networks.
            return self.net.avgpool
        return None

    def add_forward_hook_to_net(
        self, layer: str | int
    ) -> "torch.utils.hooks.RemovableHandle | None":
        """Register ``fn`` as a forward hook on the module selected by ``layer``.

        Returns
        -------
        The handle returned by ``register_forward_hook`` so the caller can
        remove the hook again, or None if no module matches ``model_name``.

        Raises
        ------
        ValueError
            If ``layer`` is not valid for the configured network.
        """
        module = self._hook_target(layer)
        if module is None:
            return None
        return module.register_forward_hook(self.fn)

    def fn(self, module, inputs, outputs):
        """Forward hook: record the hooked module's output."""
        self.feat_list.append(outputs)

    @torch.inference_mode()
    def update(
        self,
        preds: torch.Tensor,
        trues: torch.Tensor,
    ) -> None:
        """Extract and store features for a batch of image pairs.

        Parameters
        ----------
        preds :
            Tensor of predicted images, of shape (N, 3, H, W).
        trues :
            Tensor of ground-truth images, of shape (N, 3, H', W').

        Raises
        ------
        ValueError
            If ``preds`` and ``trues`` differ in length, or ``layer`` is not
            valid for the configured network.
        RuntimeError
            If the forward pass did not trigger the feature hook.
        """
        if len(preds) != len(trues):
            raise ValueError(f"preds/trues length mismatch: {len(preds)} vs {len(trues)}")
        n = len(preds)
        # data transform
        preds = T.functional.resize(preds, (224, 224))
        trues = T.functional.resize(trues, (224, 224))
        preds_trues = torch.cat([preds, trues], dim=0)
        preds_trues = self.normalize(preds_trues).float()
        # make sure that the net is in float
        self.net = self.net.float().to(self.device)

        self.feat_list = []
        # Register the hook for this call only and always remove it again.
        # Re-registering on every update without removal stacks hooks, so the
        # features would be captured several times per forward pass.
        handle = self.add_forward_hook_to_net(self.layer)
        try:
            _ = self.net(preds_trues.to(self.device))
        finally:
            if handle is not None:
                handle.remove()
        if not self.feat_list:
            raise RuntimeError(
                f"No features were captured for model {self.model_name!r}, layer {self.layer!r}."
            )
        # One forward pass yields one output of the hooked module. Use the
        # last capture so a hook registered by external code cannot duplicate
        # features.
        features = self.feat_list[-1]
        if self.model_name == "clip" and self.layer in (7, 12):
            # NOTE(review): presumably the resblocks emit (tokens, batch, dim);
            # the permute puts the batch axis first. Confirm against CLIP.
            features = features.permute(1, 0, 2)
        self.pred_features.append(features[:n])  # type: ignore
        self.true_features.append(features[n:])  # type: ignore

    def _pairwise_corr_all(
        self,
        ground_truth: np.ndarray | torch.Tensor,
        predictions: np.ndarray | torch.Tensor,
    ) -> tuple[float, float]:
        """Return the pairwise identification performance and its p-value.

        Parameters
        ----------
        ground_truth, predictions :
            Feature matrices with one row per example, in matching order.

        Returns
        -------
        perf :
            Mean fraction of incongruent pairs whose correlation is lower
            than that of the congruent pair.
        p :
            One minus the binomial CDF (success probability 0.5) evaluated
            at the number of successful comparisons.
        """
        if isinstance(ground_truth, torch.Tensor):
            ground_truth = ground_truth.detach().cpu().numpy()
        if isinstance(predictions, torch.Tensor):
            predictions = predictions.detach().cpu().numpy()
        n_items = len(ground_truth)
        r = np.corrcoef(ground_truth, predictions)
        # Keep the ground-truth x prediction quadrant of the joint matrix.
        r = r[:n_items, n_items:]
        # congruent pairs are on diagonal
        congruents = np.diag(r)
        # for each column (prediction) we count the number of rows (ground truth) for which the value is lower than the congruent (e.g. success).
        success = r < congruents
        success_cnt = np.sum(success, 0)
        # note: diagonal of 'success' is always zero so we can discard it. That's why we divide by len-1
        perf = np.mean(success_cnt) / (n_items - 1)
        n_comparisons = n_items * (n_items - 1)
        p = 1 - binom.cdf(perf * n_comparisons, n_comparisons, 0.5)
        return perf, p

    def compute(self) -> torch.Tensor:
        """Return the similarity score over all accumulated image pairs.

        For "efficientnet" and "swav" this is the mean correlation distance
        between matching feature vectors; for the other networks it is the
        pairwise identification performance.
        """
        preds_feat_tensor = dim_zero_cat(self.pred_features)  # type: ignore
        trues_feat_tensor = dim_zero_cat(self.true_features)  # type: ignore
        preds_feat: np.ndarray = (
            preds_feat_tensor.reshape((len(preds_feat_tensor), -1)).cpu().numpy()
        )
        trues_feat: np.ndarray = (
            trues_feat_tensor.reshape((len(trues_feat_tensor), -1)).cpu().numpy()
        )
        if self.model_name in ["efficientnet", "swav"]:
            distances = np.array(
                [
                    self.distance_fn(true_vec, pred_vec)
                    for true_vec, pred_vec in zip(trues_feat, preds_feat)
                ]
            )
            val = np.mean(distances)
        else:
            val = self._pairwise_corr_all(trues_feat, preds_feat)[0]
        return torch.tensor(val, device=self.pred_features[0].device)  # type: ignore
class GroupedMetric(torchmetrics.Metric):
    """
    A wrapper around a torchmetrics.Metric that allows for computing metrics per group.

    IMPORTANT: this metric does not work well with LightningModule, because the
    self.log() method does not support dictionaries of metrics.
    To use this metric, you need to add this in the on_val_epoch_end and on_test_epoch_end methods:

        metric_dict = {metric_name + "/" + k: v for k, v in grouped_metric.compute().items()}
        self.log_dict(metric_dict)
        grouped_metric.reset()

    Parameters
    ----------
    metric_name : str
        Name of the metric class, looked up first in ``TORCHMETRICS_NAMES``
        and then among this module's globals.
    kwargs : dict or None
        Keyword arguments used to instantiate one metric per group.

    Raises
    ------
    ValueError
        If ``metric_name`` cannot be resolved.
    """

    def __init__(self, metric_name: str, kwargs: dict[str, tp.Any] | None = None) -> None:
        super().__init__()
        if kwargs is None:
            kwargs = {}
        from neuraltrain.metrics.base import TORCHMETRICS_NAMES

        if metric_name in TORCHMETRICS_NAMES:
            self.base_metric_cls = TORCHMETRICS_NAMES[metric_name]
        else:
            metric_cls = globals().get(metric_name)
            if metric_cls is None:
                raise ValueError(f"Metric {metric_name} not found")
            self.base_metric_cls = metric_cls
        self.metric_kwargs = kwargs
        self.metrics = torch.nn.ModuleDict()  # store metrics per group

    @staticmethod
    def _group_key(group_id: tp.Any) -> str:
        """Return the ``ModuleDict`` key used for ``group_id``.

        Integral floats are rendered as integers ("0.0" -> "0") because
        module names are not allowed to contain a dot.
        """
        if isinstance(group_id, float) and group_id.is_integer():
            group_id = int(group_id)
        return str(group_id)

    def update(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
        groups: torch.Tensor | np.ndarray | tp.Sequence[tp.Any] | None = None,
    ) -> None:
        """
        Update each group's metric separately.

        Parameters
        ----------
        preds, target :
            Batch of predictions and targets, indexed along their first dimension.
        groups :
            A tensor, array or list of group identifiers with one entry per
            row of ``preds``. If None, all rows go into a single group "0".

        Raises
        ------
        ValueError
            If the number of group identifiers differs from ``preds.shape[0]``.
        """
        n_samples = preds.shape[0]
        if groups is None:
            group_ids = [0] * n_samples
        elif isinstance(groups, torch.Tensor):
            group_ids = groups.flatten().tolist()
        else:
            group_ids = np.asarray(groups).ravel().tolist()
        if len(group_ids) != n_samples:
            raise ValueError(
                f"groups length ({len(group_ids)}) must match preds ({n_samples})"
            )
        # Use pandas.groupby, faster than computing masks by hand
        groups_df = pd.DataFrame({"label": group_ids})
        for group_id, group in groups_df.groupby("label", sort=False):
            rows = group.index.to_numpy()
            group_key = self._group_key(group_id)
            if group_key not in self.metrics:
                self.metrics[group_key] = self.base_metric_cls(**self.metric_kwargs)
            # Align with the incoming batch; a no-op when already on that device.
            metric = self.metrics[group_key].to(preds.device)
            metric.update(preds[rows], target[rows])  # type: ignore

    def compute(self) -> dict[str, float]:
        """Return a dictionary mapping each group key to its computed metric."""
        return {
            gid: metric.compute().item()  # type: ignore[operator]
            for gid, metric in self.metrics.items()  # type: ignore
        }

    def reset(self) -> None:
        """Reset every per-group metric and this wrapper's own bookkeeping."""
        for metric in self.metrics.values():
            metric.reset()  # type: ignore
        # Also clear the wrapper's cached compute() value and update counter.
        super().reset()

    def __repr__(self) -> str:
        return f"GroupedMetric({self.base_metric_cls.__name__})"
class CharacterErrorRates(torchmetrics.text.CharErrorRate):
    """Character error rate of greedy CTC decodings, expressed in percent.

    Built on :class:`torchmetrics.text.CharErrorRate`: ``y_pred`` is
    greedy-decoded (repeats collapsed, blanks dropped), integer label IDs
    are turned into one-character strings with ``chr``, and the edit-distance
    accumulation is left to the parent class.

    Inputs
    ------
    y_pred : ``(B, T_out, n_classes)``
        Log-probs from a CTC head.
    y_true : ``(B, max_target_length)``
        Integer labels padded with ``blank_idx``. The number of labels in
        each row is recovered as ``(y_true != blank_idx).sum(-1)``.
    """

    def __init__(self, blank_idx: int = 0) -> None:
        super().__init__()
        self._blank = blank_idx

    def update(  # type: ignore[override]
        self, y_pred: torch.Tensor, y_true: torch.Tensor
    ) -> None:
        """Decode one batch and feed the resulting strings to the parent metric."""
        blank = self._blank
        best_path = y_pred.argmax(dim=-1).long()
        labels = y_true.long()
        # One device-to-host copy for all row lengths instead of one per row.
        label_counts = (labels != blank).sum(dim=-1).tolist()

        hypotheses: list[str] = []
        references: list[str] = []
        for frame_ids, label_row, n_labels in zip(best_path, labels, label_counts):
            collapsed = torch.unique_consecutive(frame_ids)
            decoded = collapsed[collapsed != blank].tolist()
            reference = label_row[:n_labels].tolist()
            # Every label id becomes its own single-codepoint "character",
            # which is all the string-based parent metric needs.
            hypotheses.append("".join(map(chr, decoded)))
            references.append("".join(map(chr, reference)))
        super().update(hypotheses, references)

    def compute(self) -> torch.Tensor:
        """Return the character error rate in percent.

        The parent reports CER in [0, 1]; it is scaled by 100 here. When
        ``total`` is 0 (e.g. nothing was accumulated), zero is returned
        instead of a NaN/Inf.
        """
        nothing_accumulated = int(self.total) == 0
        if nothing_accumulated:
            return torch.zeros_like(self.errors)
        return super().compute() * 100.0