neuralbench.callbacks.PlotRegressionVectors

class neuralbench.callbacks.PlotRegressionVectors(num_samples: int = 10)

Visualize predictions vs ground truth for multi-dimensional regression tasks.

This callback samples examples from the test set and plots each predicted vector overlaid on its ground-truth vector for comparison. Each sample is drawn in its own subplot, with the dimension index on the x-axis and the magnitude on the y-axis.
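The subplot layout described above can be illustrated with a small data-preparation sketch. It assumes predictions and targets are available as plain sequences of floats, one vector per sample. `build_subplot_series` is a hypothetical helper, not part of neuralbench, and it prepares only the x/y series. It does not reproduce the callback's actual plotting code.

```python
from typing import List, Sequence, Tuple

# (dimension indices, predicted values, ground-truth values) for one subplot
Series = Tuple[List[int], List[float], List[float]]


def build_subplot_series(
    predictions: Sequence[Sequence[float]],
    targets: Sequence[Sequence[float]],
) -> List[Series]:
    """Hypothetical helper: return one (dims, pred, target) triple per sample.

    dims is the x-axis (dimension index); pred and target are the y-values
    (magnitude) that would be overlaid in the same subplot.
    """
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same number of samples")
    series: List[Series] = []
    for pred, target in zip(predictions, targets):
        if len(pred) != len(target):
            raise ValueError("each prediction must have the same dimension as its target")
        dims = list(range(len(pred)))  # x-axis: 0, 1, ..., D-1
        series.append((dims, [float(v) for v in pred], [float(v) for v in target]))
    return series
```

With matplotlib, each triple would map naturally onto one axes object, for example `ax.plot(dims, pred, label="prediction")` followed by `ax.plot(dims, target, label="ground truth")`.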

Parameters:

num_samples (int, default=10) – Number of samples to visualize from the test set.

on_test_batch_end(trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx=0)

Collect samples uniformly across the test set.
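The documentation does not say how "uniformly" is implemented. The following is one plausible sketch, assuming the total number of test examples and a fixed batch size are known in advance. It picks `num_samples` evenly spaced global indices and then groups them by batch, so that a per-batch hook only has to check whether the current `batch_idx` holds any selected rows. Both helper names are hypothetical.

```python
from typing import Dict, List


def evenly_spaced_indices(total: int, num_samples: int) -> List[int]:
    """Hypothetical: pick up to num_samples distinct indices spread over range(total)."""
    if total <= 0 or num_samples <= 0:
        return []
    k = min(num_samples, total)
    # Integer stride: starts at 0 and spans the whole set without duplicates.
    return [i * total // k for i in range(k)]


def group_by_batch(indices: List[int], batch_size: int) -> Dict[int, List[int]]:
    """Hypothetical: map batch_idx -> positions within that batch to keep."""
    groups: Dict[int, List[int]] = {}
    for idx in indices:
        groups.setdefault(idx // batch_size, []).append(idx % batch_size)
    return groups
```

For example, `group_by_batch(evenly_spaced_indices(10, 4), 4)` selects global indices 0, 2, 5 and 7, which gives positions `[0, 2]` in batch 0 and `[1, 3]` in batch 1.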

on_test_epoch_end(trainer: Trainer, pl_module: LightningModule)

Generate and log visualization at the end of testing.

on_test_epoch_start(trainer: Trainer, pl_module: LightningModule)

Initialize sample collection at the start of testing.
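Taken together, the three hooks form a reset/collect/emit cycle. Lightning calls `on_test_epoch_start` once, then `on_test_batch_end` after every test batch, and finally `on_test_epoch_end`. The plain-Python stand-in below mirrors that cycle without depending on Lightning or neuralbench. `SampleBuffer` and its methods are hypothetical. The buffer keeps whatever pairs it is offered, up to `num_samples`. Deciding which rows to offer (for example, uniformly across the test set) is left to the caller.

```python
from typing import List, Sequence, Tuple

# (predicted vector, ground-truth vector)
Pair = Tuple[List[float], List[float]]


class SampleBuffer:
    """Hypothetical stand-in for the callback's per-test-run state."""

    def __init__(self, num_samples: int = 10) -> None:
        self.num_samples = num_samples
        self.samples: List[Pair] = []

    def reset(self) -> None:
        # Mirrors on_test_epoch_start: discard anything left from a previous run.
        self.samples = []

    def collect(
        self,
        preds: Sequence[Sequence[float]],
        targets: Sequence[Sequence[float]],
    ) -> None:
        # Mirrors on_test_batch_end: store offered pairs, capped at num_samples.
        for pred, target in zip(preds, targets):
            if len(self.samples) >= self.num_samples:
                return
            self.samples.append((list(pred), list(target)))

    def flush(self) -> List[Pair]:
        # Mirrors on_test_epoch_end: hand samples to plotting/logging, then clear.
        out, self.samples = self.samples, []
        return out
```

Resetting at the start of each test run, rather than only in `__init__`, keeps samples from an earlier `trainer.test(...)` call out of the next figure.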