# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import typing as tp
import warnings
from collections import defaultdict
from pathlib import Path
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torchmetrics
from lightning.pytorch.callbacks import Callback
from matplotlib.figure import Figure
from sklearn.metrics import ConfusionMatrixDisplay
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from neuraltrain.metrics.utils import agg_per_group, agg_retrieval_preds
if tp.TYPE_CHECKING:
from neuraltrain.utils import StandardScaler
LOGGER = logging.getLogger(__name__)
def _set_plot_theme() -> None:
    """Configure seaborn with the theme shared by the neuralbench callback figures."""
    # "paper" context with the "white" style, passed by keyword to `sns.set_theme`.
    theme_options = {"context": "paper", "style": "white"}
    sns.set_theme(**theme_options)
[docs]
class TestFullRetrievalMetrics(Callback):
    """Accumulate predictions on the entire val/test set before evaluating retrieval metrics.

    Requires the ``pl.LightningModule`` this callback is attached to to define an
    ``nn.ModuleDict`` named ``test_full_retrieval_metrics`` containing the metrics to
    evaluate. Only entries whose key starts with the current step name (``"val"`` or
    ``"test"``) and contains ``"retrieval"`` are evaluated.

    Parameters
    ----------
    event_type:
        Stored on the callback; it is not read anywhere else in this class.
    event_field:
        Name of the trigger attribute collected for each segment (falls back to
        ``category`` when the trigger does not expose ``event_field``).
    retrieval_set_sizes:
        Retrieval set sizes to evaluate. ``None`` keeps all groups; an integer keeps
        only that many most frequent groups.
    save_outputs:
        If True, the accumulated outputs are saved with ``torch.save`` under
        ``<trainer.logger.save_dir>/retrieval_outputs``.
    logger:
        Optional object exposing ``.info(...)`` used to report segment counts and
        metric values.
    eval_val:
        If True, metrics are also accumulated and computed on validation epochs.
    """

    def __init__(
        self,
        event_type: tp.Literal["Word", "Image"] = "Word",
        event_field: tp.Literal["text", "category"] = "text",
        retrieval_set_sizes: tuple = (None, 250),
        save_outputs: bool = False,
        logger: tp.Any = None,
        eval_val: bool = False,
    ):
        self.event_type = event_type
        self.event_field = event_field
        self.retrieval_set_sizes = retrieval_set_sizes
        self.save_outputs = save_outputs
        self.logger = logger
        self.eval_val = eval_val
        # Maps dataloader index -> defaultdict(list) of accumulated fields.
        self.full_outputs: dict[str | int, tp.Any] = {}

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str):
        """Check that the module exposes at least one usable retrieval metric.

        Raises
        ------
        ValueError
            If ``pl_module.test_full_retrieval_metrics`` is missing, is not an
            ``nn.ModuleDict``, or holds no key that starts with ``"val"``/``"test"``
            and contains ``"retrieval"``.
        """
        metrics = getattr(pl_module, "test_full_retrieval_metrics", None)
        if not isinstance(metrics, nn.ModuleDict):
            raise ValueError(
                "The LightningModule needs a test_full_retrieval_metrics ModuleDict "
                "that contains the retrieval metrics to evaluate on the full test set."
            )
        has_retrieval_metric = any(
            name.startswith(("val", "test")) and "retrieval" in name
            for name in metrics.keys()
        )
        if not has_retrieval_metric:
            raise ValueError(
                "The LightningModule needs at least one retrieval metric with 'retrieval' in its name."
            )

    def get_n_loaders(
        self, trainer: pl.Trainer, step_name: tp.Literal["val", "test"]
    ) -> int:
        """Return the number of dataloaders the trainer holds for ``step_name``.

        ``None`` counts as 0, a single ``DataLoader`` as 1, and any other container
        as ``len(container)``.
        """
        loader = getattr(trainer, step_name + "_dataloaders")
        if loader is None:
            return 0
        if isinstance(loader, DataLoader):
            return 1
        return len(loader)

    def _reset_outputs(
        self, trainer: pl.Trainer, step_name: tp.Literal["val", "test"]
    ) -> None:
        """Create one empty accumulation buffer per dataloader of ``step_name``."""
        # Use get_n_loaders rather than len(...): len() of a single DataLoader is its
        # number of batches, not the number of loaders.
        n_loaders = self.get_n_loaders(trainer, step_name)
        self.full_outputs = {idx: defaultdict(list) for idx in range(n_loaders)}
        # Top-level entry kept from the original implementation; it is not read
        # anywhere in this class.
        self.full_outputs["y_true"] = dict()

    def on_validation_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        """Reset the accumulation buffers before a validation epoch."""
        assert trainer.val_dataloaders is not None
        self._reset_outputs(trainer, "val")

    def on_test_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Reset the accumulation buffers before a test epoch."""
        assert trainer.test_dataloaders is not None
        self._reset_outputs(trainer, "test")

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx,
        dataloader_idx=0,
    ):
        """Accumulate the batch outputs, but only when ``eval_val`` is enabled."""
        if self.eval_val:
            self._collate_outputs(outputs, batch, dataloader_idx)

    def on_test_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx,
        dataloader_idx=0,
    ):
        """Accumulate the outputs of a test batch."""
        self._collate_outputs(outputs, batch, dataloader_idx)

    def on_validation_epoch_end(self, trainer, pl_module) -> None:
        """Compute validation metrics when ``eval_val`` is enabled."""
        if self.eval_val:
            self._on_epoch_end(trainer, pl_module, "val")

    def on_test_epoch_end(self, trainer, pl_module) -> None:
        """Compute test metrics on the accumulated outputs."""
        self._on_epoch_end(trainer, pl_module, "test")

    def _on_epoch_end(self, trainer, pl_module, step_name):
        """Report, evaluate, optionally save, then release the accumulated outputs."""
        if self.logger is not None:
            self.logger.info(
                f"{step_name.capitalize()} segments: {len(self.full_outputs[0]['subject'])}"
            )
        self._compute_metrics(trainer, pl_module, step_name=step_name)
        if self.save_outputs:
            self._save_outputs(trainer, step_name=step_name)
        # Release the accumulated tensors but keep the attribute defined, so later
        # accesses do not fail with AttributeError.
        self.full_outputs = {}

    def _save_outputs(self, trainer, step_name):
        """Save each dataloader's accumulated outputs to ``<step_name>_<idx>.pt``."""
        n_loaders = self.get_n_loaders(trainer, step_name)
        save_dir = Path(trainer.logger.save_dir) / "retrieval_outputs"
        save_dir.mkdir(parents=True, exist_ok=True)
        for dataloader_idx in range(n_loaders):
            full = self.full_outputs[dataloader_idx]
            # Replace the per-batch lists by single concatenated tensors (in place).
            for key in ["y_pred", "y_true"]:
                full[key] = torch.cat(full[key], dim=0)
            file = save_dir / f"{step_name}_{dataloader_idx}.pt"
            torch.save(full, file)

    def _collate_outputs(self, outputs, batch, dataloader_idx):
        """Append one batch of predictions, targets and segment metadata.

        ``outputs`` must unpack into ``(y_pred, y_true)``; both are moved to CPU.
        For every segment of ``batch.segments`` the trigger's ``event_field`` (or
        ``category`` as fallback), subject, and sequence id or filepath are recorded.
        """
        y_pred, y_true = outputs
        full = self.full_outputs[dataloader_idx]
        full["y_pred"].append(y_pred.cpu())
        full["y_true"].append(y_true.cpu())

        # for each prediction, we retrieve the corresponding extractor item (category for image
        # datasets, text for word datasets)
        for segment in batch.segments:
            trigger = segment.trigger
            if hasattr(trigger, self.event_field):
                full[self.event_field].append(getattr(trigger, self.event_field))
            elif hasattr(trigger, "category"):
                full["category"].append(trigger.category)

            extra = trigger.extra
            full["subject"].append(extra["subject"])
            if "sequence_id" in extra:
                full["sequence"].append(extra["sequence_id"])
            elif hasattr(trigger, "filepath"):
                full["filepath"].append(trigger.filepath)

    def _compute_metrics(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        step_name: tp.Literal["val", "test"],
    ) -> None:
        """Evaluate and log every matching metric for each loader and set size.

        When several dataloaders are used, ``_<dataloader_idx>`` is appended to
        the logged key.
        """
        n_loaders = self.get_n_loaders(trainer, step_name)
        test_full_metrics = {
            k: v
            for k, v in pl_module.test_full_retrieval_metrics.items()  # type: ignore
            if (k.startswith(step_name) and "retrieval" in k)
        }

        for dataloader_idx in range(n_loaders):
            full = self.full_outputs[dataloader_idx]
            y_pred = torch.cat(full["y_pred"], dim=0)
            y_true = torch.cat(full["y_true"], dim=0)
            subjects_pred = full["subject"]
            # Group by text when it was collected, otherwise by filepath.
            # NOTE(review): if neither was collected, `full["filepath"]` is an empty
            # list (defaultdict) -- confirm this case cannot happen upstream.
            if "text" in full:
                groups_pred = full["text"]
            else:
                groups_pred = full["filepath"]

            for retrieval_set_size in self.retrieval_set_sizes:
                out = self._get_test_full_metrics(
                    y_pred=y_pred,
                    y_true=y_true,
                    groups_pred=groups_pred,
                    subjects_pred=subjects_pred,
                    metrics=test_full_metrics,
                    retrieval_set_size=retrieval_set_size,
                    step_name=step_name,
                )
                for key, value in out.items():
                    if n_loaders > 1:
                        key += f"_{dataloader_idx}"
                    pl_module.log(key, value)
                    if self.logger is not None:
                        self.logger.info(f"{key}: {value}")

    @staticmethod
    def _get_test_full_metrics(
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
        groups_pred: list,
        subjects_pred: list,
        metrics: dict[str, torchmetrics.Metric],
        step_name: tp.Literal["val", "test"],
        retrieval_set_size: int | None = None,
    ) -> dict[str, torch.Tensor]:
        """Compute retrieval metrics with optional aggregation of predictions.

        This method processes the predictions and true labels to compute the specified
        metrics, potentially aggregating results by groups or subjects.

        Parameters
        ----------
        y_pred:
            predicted labels or scores from the model
        y_true:
            true labels for the data.
        groups_pred:
            group labels for each prediction.
        subjects_pred:
            subject id for each prediction.
        metrics:
            a dictionary of metric functions to be computed.
        step_name:
            the name of the current step (e.g. 'val', 'test').
        retrieval_set_size:
            the number of most frequent groups to keep for the metrics. Defaults to
            None (keep everything).

        Returns
        -------
        dict:
            A dictionary containing computed metric values. When ``retrieval_set_size``
            is set, ``_size=<retrieval_set_size>`` is appended to each metric name.

        Note
        ----
        The aggregation strategy depends on the metric name:
        - ends with "subject-agg": predictions are aggregated by subject for a given group.
        - ends with "instance-agg": all predictions for a given group are aggregated.
        - otherwise: every prediction is kept as its own instance.
        Names containing "subject-ind" are, on the test step, evaluated separately for
        each subject and the mean over subjects is reported. Names containing
        "subject-ind" but not "agg" additionally get a ``<name>_macro`` entry when the
        metric exposes ``_compute_macro_average``.
        """
        out = {}

        # Keep only the most frequent groups
        if retrieval_set_size is not None:
            groups_df = pd.DataFrame({"label": groups_pred})
            counts = groups_df.label.value_counts(sort=True, ascending=False)
            most_frequent = set(counts.index[:retrieval_set_size])
            if len(counts) < retrieval_set_size:
                msg = f"The number of unique labels ({len(counts)}) is lower than the retrieval size ({retrieval_set_size})."
                warnings.warn(msg, UserWarning)
            msg = f"Retrieval set items (and count) in test set:\n{counts.iloc[:retrieval_set_size].to_dict()}"
            LOGGER.warning(msg)

            mask = groups_df.label.isin(most_frequent).to_numpy()
            y_pred, y_true = y_pred[mask], y_true[mask]
            indices = np.where(mask)[0]
            groups_pred = [groups_pred[i] for i in indices]
            subjects_pred = [subjects_pred[i] for i in indices]

        # Remove repetitions in retrieval set
        agg_y_true, agg_groups_true = agg_per_group(
            y_true,
            groups=groups_pred,
            agg_func="mean",  # Use "mean" instead of "first" because events can have different
            # latent representation in speech decoding experiments
        )

        for metric_name, metric in tqdm(
            metrics.items(), "Evaluating full test set retrieval metrics"
        ):
            metric = metric.to("cpu")

            subjects: list | None
            if metric_name.endswith("subject-agg"):
                subjects = subjects_pred
            elif metric_name.endswith("instance-agg"):
                subjects = None
            else:
                # One distinct pseudo-subject per prediction.
                subjects = torch.arange(y_pred.shape[0], device=y_pred.device).tolist()

            agg_y_pred, agg_groups_pred = agg_retrieval_preds(
                y_pred,
                groups_pred=groups_pred,
                subjects_pred=subjects,
            )

            if retrieval_set_size is not None:
                metric_name += f"_size={retrieval_set_size}"

            if "subject-ind" in metric_name and step_name == "test":
                # Compute performance per subject
                individual_subjects = pd.DataFrame({"id": subjects_pred})
                results = {}
                for subj, grp in individual_subjects.groupby("id"):
                    metric.reset()
                    metric.update(
                        agg_y_pred[grp.index.values],
                        agg_y_true,
                        [
                            agg_grp_pred
                            for i, agg_grp_pred in enumerate(agg_groups_pred)
                            if i in grp.index
                        ],
                        agg_groups_true,
                    )
                    results[subj] = metric.compute().item()
                out[metric_name] = torch.tensor(np.mean(list(results.values())))
            else:
                metric.reset()
                metric.update(agg_y_pred, agg_y_true, agg_groups_pred, agg_groups_true)
                out[metric_name] = metric.compute()

            if "subject-ind" in metric_name and "agg" not in metric_name:
                # Compute frequency-corrected average
                if hasattr(metric, "_compute_macro_average"):
                    ranks = metric._compute_ranks(  # type: ignore
                        agg_y_pred, agg_y_true, agg_groups_pred, agg_groups_true
                    )
                    macro_average = np.mean(
                        list(
                            metric._compute_macro_average(ranks, agg_groups_pred).values()  # type: ignore
                        )
                    )
                    out[metric_name + "_macro"] = torch.tensor(macro_average)
                else:
                    warnings.warn(
                        f"Metric {metric} does not implement `_compute_macro_average`."
                    )

        return out
[docs]
class RecordingLevelEval(Callback):
    """Callback to evaluate average prediction over each recording (timeline).

    During testing, the argmax prediction of every segment is counted per timeline.
    At the end of the test epoch the counts are normalised into per-class
    proportions for each timeline and passed to every metric of the module's
    ``test_full_metrics`` ``nn.ModuleDict``.
    """

    def __init__(self) -> None:
        # timeline -> {"y_true": label, "class_counts": votes per class}
        self.test_outputs: dict[str, dict[str, tp.Any]] = {}
        self.num_classes: int | None = None

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        """Check that the module defines a ``test_full_metrics`` ModuleDict.

        Raises
        ------
        ValueError
            If the attribute is missing or is not an ``nn.ModuleDict``.
        """
        if not isinstance(getattr(pl_module, "test_full_metrics", None), nn.ModuleDict):
            raise ValueError(
                "The LightningModule needs a test_full_metrics ModuleDict that contains the "
                "metrics to evaluate on the full test set."
            )

    def on_test_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Clear the per-timeline counts before a test epoch."""
        self.test_outputs = {}
        self.num_classes = None

    def on_test_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx,
        dataloader_idx=0,
    ) -> None:
        """Add one vote per segment to the predicted class of its timeline.

        ``outputs[0]`` holds the logits (argmax is taken over dim 1) and
        ``outputs[1]`` the targets. The timeline of a segment is the first value
        of ``segment.events.timeline``.
        """
        y_pred_logits = outputs[0].cpu()
        y_pred = y_pred_logits.argmax(dim=1)
        y_true = outputs[1].cpu()
        timelines = [segment.events.timeline.iloc[0] for segment in batch.segments]

        # Infer number of classes from model output shape
        if self.num_classes is None:
            self.num_classes = int(y_pred_logits.shape[1])

        for timeline, y_true_, y_pred_ in zip(timelines, y_true, y_pred):
            # Initialize entry for new timeline; its label is taken from the first
            # segment seen for that timeline.
            if timeline not in self.test_outputs:
                self.test_outputs[timeline] = {
                    "y_true": y_true_.item(),
                    "class_counts": [0] * self.num_classes,
                }
            # Update class count for the predicted class
            self.test_outputs[timeline]["class_counts"][y_pred_.item()] += 1

    def on_test_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Turn the vote counts into class proportions and update/log the metrics."""
        # Nothing was accumulated (e.g. no test batch): skip instead of failing on
        # the missing counts.
        if not self.test_outputs or self.num_classes is None:
            LOGGER.warning("No test outputs collected for recording-level evaluation")
            return

        entries = list(self.test_outputs.values())
        # Rows follow the insertion order of the timelines; float64 array of shape
        # (n_recordings, num_classes).
        class_counts = np.asarray(
            [entry["class_counts"] for entry in entries], dtype=np.float64
        )
        totals = class_counts.sum(axis=1, keepdims=True)
        # Proportion of votes per class; rows without any vote stay at 0.0.
        probs = np.divide(
            class_counts,
            totals,
            out=np.zeros_like(class_counts),
            where=totals > 0,
        )
        y_pred_probs = torch.from_numpy(probs)
        # Convert y_true to int64 explicitly to avoid object dtype issues
        y_true = torch.tensor([entry["y_true"] for entry in entries], dtype=torch.int64)

        n_recordings = len(entries)
        for metric_name, metric in pl_module.test_full_metrics.items():  # type: ignore
            metric.to("cpu")
            metric.update(y_pred_probs, y_true)
            pl_module.log(
                metric_name,
                metric,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                batch_size=n_recordings,
            )
[docs]
def plot_confusion_matrix(
    cm: np.ndarray,
    labels: list[str] | None = None,
    cmap: str = "viridis",
    xticks_rotation: float | str = 45,
) -> Figure:
    """
    Plot confusion matrix/matrices with a shared colorbar.

    Parameters
    ----------
    cm : array-like
        Confusion matrix. Can be:
        - 2D array of shape (n_classes, n_classes) for multiclass case
        - 3D array of shape (n_labels, 2, 2) for multilabel case
    labels : list of str, optional
        Labels for the classes/categories. If None, the multilabel subplots are
        titled ``"Label <i>"`` and the multiclass display receives
        ``display_labels=None``.
    cmap : str, default='viridis'
        Colormap to use for the heatmap.
    xticks_rotation : float or str, default=45
        Forwarded to ``ConfusionMatrixDisplay.plot`` as the x tick label rotation.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure object.
    """
    _set_plot_theme()
    cm = np.asarray(cm)

    # Detect if multiclass (2D) or multilabel (3D with 2x2 matrices)
    is_multilabel = cm.ndim == 3 and cm.shape[1] == 2 and cm.shape[2] == 2

    if is_multilabel:
        # Multilabel case: multiple 2x2 confusion matrices on a near-square grid
        n_labels = cm.shape[0]
        n_cols = int(np.ceil(np.sqrt(n_labels)))
        n_rows = int(np.ceil(n_labels / n_cols))
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), constrained_layout=True
        )
        axes = np.atleast_1d(axes).flatten()

        # Find global min/max for consistent colorbar
        vmin = cm.min()
        vmax = cm.max()

        # Plot each confusion matrix using ConfusionMatrixDisplay
        displays = []
        for i in range(n_labels):
            label = labels[i] if labels is not None else f"Label {i}"
            disp = ConfusionMatrixDisplay(
                confusion_matrix=cm[i], display_labels=["0", "1"]
            )
            disp.plot(
                ax=axes[i],
                cmap=cmap,
                colorbar=False,
                im_kw={"vmin": vmin, "vmax": vmax},
                xticks_rotation=xticks_rotation,
            )
            axes[i].set_title(label)
            displays.append(disp)

        # Hide unused subplots
        for unused_ax in axes[n_labels:]:
            unused_ax.set_visible(False)

        # Add a common colorbar (using list of visible axes only); all images share
        # vmin/vmax, so the first one is representative.
        visible_axes = list(axes[:n_labels])
        fig.colorbar(displays[0].im_, ax=visible_axes, label="Count")
    else:
        # Multiclass case: single confusion matrix
        fig, ax = plt.subplots(figsize=(8, 8))
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
        disp.plot(ax=ax, cmap=cmap, xticks_rotation=xticks_rotation)
        ax.set_title("Confusion Matrix")
        fig.tight_layout()

    return fig
[docs]
class PlotConfusionMatrix(Callback):
    """Render the module's confusion-matrix metric and send it to W&B as an image.

    Nothing happens unless the trainer's logger is a ``WandbLogger`` and
    ``pl_module.metrics`` exposes ``"<step>/confusion_matrix"``.

    Parameters
    ----------
    labels : list of str, optional
        Class labels forwarded to ``plot_confusion_matrix``.
    """

    def __init__(self, labels: list[str] | None = None):
        self.labels = labels

    def _plot_and_log(
        self,
        step_name: tp.Literal["val", "test"],
        trainer,
        pl_module,
    ):
        """Compute, plot, reset and log the confusion matrix of ``step_name``."""
        if not isinstance(trainer.logger, pl.loggers.wandb.WandbLogger):
            return

        # Imported lazily: only needed once a W&B logger is in use.
        import wandb

        metric_name = step_name + "/confusion_matrix"
        metric = getattr(pl_module.metrics, metric_name, None)
        if metric is None:
            return

        matrix = metric.compute().detach().cpu().numpy()
        figure = plot_confusion_matrix(matrix, labels=self.labels)
        metric.reset()
        try:
            wandb.log({metric_name: wandb.Image(figure)})
        except wandb.errors.Error as e:
            # A failed upload must not interrupt the run.
            LOGGER.warning("Failed to log confusion matrix to W&B: %s", e)
        plt.close("all")

    def on_validation_epoch_end(self, trainer, pl_module):
        """Log the validation confusion matrix from the global-zero process only."""
        if not trainer.is_global_zero:
            return
        self._plot_and_log("val", trainer, pl_module)

    def on_test_epoch_end(self, trainer, pl_module):
        """Log the test confusion matrix from the global-zero process only."""
        if not trainer.is_global_zero:
            return
        self._plot_and_log("test", trainer, pl_module)
[docs]
class PlotRegressionVectors(Callback):
    """Overlay predicted and ground-truth vectors for a few test samples.

    A handful of test batches, evenly spaced over the test set, each contribute
    their first sample. Every sample gets its own subplot with the dimension index
    on the x-axis and the magnitude on the y-axis. The figure is logged to W&B
    under ``"test/regression_vectors"``.

    Parameters
    ----------
    num_samples : int, default=10
        Number of samples to visualize from the test set.
    """

    def __init__(self, num_samples: int = 10):
        self.num_samples = num_samples
        self.samples_collected: list[tuple[torch.Tensor, torch.Tensor]] = []
        # Batch indices to sample from; computed lazily on the first test batch.
        self.sample_indices: set[int] | None = None

    def on_test_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Forget the samples and batch indices of any previous test run."""
        self.samples_collected = []
        self.sample_indices = None

    def on_test_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx,
        dataloader_idx=0,
    ):
        """Keep the first sample of the batch when ``batch_idx`` is a sampling index."""
        if len(self.samples_collected) >= self.num_samples:
            return

        y_pred, y_true = outputs

        if self.sample_indices is None:
            n_batches = trainer.num_test_batches[dataloader_idx]
            if n_batches > 0:
                # Evenly spaced batch indices; duplicates collapse in the set.
                self.sample_indices = {
                    int(i * n_batches / self.num_samples)
                    for i in range(self.num_samples)
                }

        if self.sample_indices is None or batch_idx not in self.sample_indices:
            return

        if len(self.samples_collected) < self.num_samples:
            sample = (y_pred[0].detach().cpu(), y_true[0].detach().cpu())
            self.samples_collected.append(sample)

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Build the figure and log it to W&B (global-zero process only)."""
        if not trainer.is_global_zero:
            return

        if not self.samples_collected:
            LOGGER.warning("No samples collected for regression vector visualization")
            return

        # Only log to wandb
        if not isinstance(trainer.logger, pl.loggers.wandb.WandbLogger):
            return

        import wandb

        try:
            fig = self._create_figure()
            wandb.log({"test/regression_vectors": wandb.Image(fig)})
            plt.close(fig)
        except Exception as e:
            # Visualization is best effort: report the failure and clean up.
            LOGGER.error(f"Failed to create regression vector visualization: {e}")
            plt.close("all")

    def _create_figure(self) -> Figure:
        """Return a figure with one prediction/ground-truth overlay per sample."""
        n_samples = len(self.samples_collected)

        # Determine grid layout (prefer 2 columns)
        n_cols = min(2, n_samples)
        n_rows = int(np.ceil(n_samples / n_cols))

        # squeeze=False always yields a 2D array of axes, so the single-subplot
        # case needs no special handling.
        fig, axes_grid = plt.subplots(
            n_rows,
            n_cols,
            figsize=(6 * n_cols, 4 * n_rows),
            constrained_layout=True,
            squeeze=False,
        )
        axes = axes_grid.flatten()

        for idx, (y_pred, y_true) in enumerate(self.samples_collected):
            ax = axes[idx]

            truth = y_true.numpy().flatten()
            prediction = y_pred.numpy().flatten()
            # x positions: one per dimension of the ground-truth vector
            dimensions = np.arange(len(truth))

            ax.plot(
                dimensions,
                truth,
                label="Ground Truth",
                color="tab:blue",
                linewidth=1.5,
            )
            ax.plot(
                dimensions,
                prediction,
                label="Prediction",
                color="tab:orange",
                linewidth=1.5,
                alpha=0.8,
            )

            ax.set_xlabel("Dimension")
            ax.set_ylabel("Magnitude")
            ax.set_title(f"Sample {idx + 1}")
            ax.legend(loc="best")
            ax.grid(True, alpha=0.3)

        # Hide unused subplots
        for spare_ax in axes[n_samples:]:
            spare_ax.set_visible(False)

        return fig
class PlotRegressionScatter(Callback):
    """Scatter plot of predictions against ground truth for 1D regression tasks.

    All test predictions are gathered first. The resulting scatter plot shows an
    identity line and a linear fit, annotated with slope, intercept, Pearson r
    and R^2. When the ``pl_module`` has a ``target_scaler`` whose ``_mean`` is set,
    predictions and targets are mapped back to original units before plotting.

    The figure is logged to W&B under ``"test/regression_scatter"``.
    """

    def __init__(self):
        self.y_preds: list[torch.Tensor] = []
        self.y_trues: list[torch.Tensor] = []

    def on_test_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Start the test epoch with empty accumulators."""
        self.y_preds = []
        self.y_trues = []

    def on_test_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx,
        dataloader_idx=0,
    ):
        """Store detached CPU copies of the batch predictions and targets."""
        y_pred, y_true = outputs
        self.y_preds.append(y_pred.detach().cpu())
        self.y_trues.append(y_true.detach().cpu())

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Create the scatter figure and log it to W&B (global-zero process only)."""
        if not trainer.is_global_zero:
            return

        if not self.y_preds:
            LOGGER.warning("No samples collected for regression scatter visualization")
            return

        if not isinstance(trainer.logger, pl.loggers.wandb.WandbLogger):
            return

        import wandb

        y_pred = torch.cat(self.y_preds).squeeze()
        y_true = torch.cat(self.y_trues).squeeze()

        # Undo target scaling only when a fitted scaler is available.
        scaler = getattr(pl_module, "target_scaler", None)
        if scaler is not None and scaler._mean is not None:
            y_pred = self._inverse_transform(y_pred, scaler)
            y_true = self._inverse_transform(y_true, scaler)

        try:
            fig = self._create_figure(y_true.numpy().flatten(), y_pred.numpy().flatten())
            wandb.log({"test/regression_scatter": wandb.Image(fig)})
            plt.close(fig)
        except Exception as e:
            # Visualization is best effort: report the failure and clean up.
            LOGGER.error(f"Failed to create regression scatter visualization: {e}")
            plt.close("all")

    @staticmethod
    def _inverse_transform(
        x: torch.Tensor, target_scaler: "StandardScaler"
    ) -> torch.Tensor:
        """Map scaled values back to original units: ``x * scale + mean``."""
        assert target_scaler._mean is not None and target_scaler._scale is not None
        scale = target_scaler._scale.cpu()
        mean = target_scaler._mean.cpu()
        return x * scale + mean

    @staticmethod
    def _create_figure(y_true: np.ndarray, y_pred: np.ndarray) -> Figure:
        """Draw the scatter, the identity line, the linear fit and its statistics."""
        fig, ax = plt.subplots(figsize=(6, 6), dpi=150, constrained_layout=True)
        ax.scatter(y_true, y_pred, s=8, alpha=0.4, edgecolors="none")

        # Common value range of both arrays, padded by 5% (or by 1.0 when flat).
        lo = min(y_true.min(), y_pred.min())
        hi = max(y_true.max(), y_pred.max())
        margin = 0.05 * (hi - lo) if hi > lo else 1.0
        span = [lo - margin, hi + margin]

        ax.plot(span, span, ls="--", color="0.5", lw=1, label="Identity")

        # Least-squares line of predictions as a function of ground truth.
        slope, intercept = np.polyfit(y_true, y_pred, 1)
        fit_x = np.array(span)
        fit_y = slope * fit_x + intercept
        ax.plot(fit_x, fit_y, color="tab:red", lw=1.5, label="Fit")

        r = np.corrcoef(y_true, y_pred)[0, 1]
        r2 = r**2

        annotation = "\n".join(
            [
                f"slope = {slope:.3f}",
                f"intercept = {intercept:.3f}",
                f"r = {r:.3f}",
                f"R² = {r2:.3f}",
            ]
        )
        ax.text(
            0.05,
            0.95,
            annotation,
            transform=ax.transAxes,
            fontsize=9,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

        ax.set_xlabel("Ground Truth")
        ax.set_ylabel("Prediction")
        ax.legend(loc="lower right")
        ax.grid(True, alpha=0.3)
        return fig