neuralbench.callbacks.TestFullRetrievalMetrics
- class neuralbench.callbacks.TestFullRetrievalMetrics(event_type: Literal['Word', 'Image'] = 'Word', event_field: Literal['text', 'category'] = 'text', retrieval_set_sizes: tuple = (None, 250), save_outputs: bool = False, logger: Any = None, eval_val: bool = False)
Accumulates predictions over the entire test set before evaluating metrics, rather than evaluating them batch by batch.
The pl.LightningModule this callback is attached to must define a torch.nn.ModuleDict named retrieval_metrics that contains the metrics to evaluate.
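A likely motivation for accumulating is that retrieval scores depend on the candidate pool: if each batch is scored separately, a prediction only competes against the targets in its own batch. The stdlib-only sketch below illustrates this with toy 3-D embeddings and a hypothetical `top1_retrieval_accuracy` helper based on cosine similarity. It is not neuralbench code; in the real callback, the metrics come from the module's `retrieval_metrics`.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def top1_retrieval_accuracy(preds: List[Vector], targets: List[Vector]) -> float:
    """Hypothetical helper: fraction of predictions whose most similar
    target is the one at the same index. Every target is a candidate,
    so the score depends on how many candidates are passed in."""
    hits = 0
    for i, p in enumerate(preds):
        scores = [cosine(p, t) for t in targets]
        best = max(range(len(targets)), key=scores.__getitem__)
        hits += int(best == i)
    return hits / len(preds)


# Toy data (invented for illustration): prediction i should retrieve target i.
preds = [(0.2, 0.0, 1.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), (1.0, 1.0, 0.0)]
targets = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), (1.0, 1.0, 0.0)]
batch_size = 2

# Per-batch scoring: each prediction is ranked against 2 candidates only.
per_batch = [
    top1_retrieval_accuracy(preds[i:i + batch_size], targets[i:i + batch_size])
    for i in range(0, len(preds), batch_size)
]
per_batch_mean = sum(per_batch) / len(per_batch)

# Full-set scoring: each prediction is ranked against all 4 candidates.
full_set = top1_retrieval_accuracy(preds, targets)

print(per_batch_mean)  # 1.0
print(full_set)        # 0.75: prediction 0 is now closer to target 2
```

Averaging per-batch scores overstates performance here because prediction 0 never sees its strongest distractor, which sits in the other batch.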
- on_test_batch_end(trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx=0)
Called when the test batch ends.
- on_test_epoch_start(trainer: Trainer, pl_module: LightningModule)
Called when the test epoch begins.
- on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx=0)
Called when the validation batch ends.
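The hooks above suggest a reset / accumulate / evaluate lifecycle. The following stdlib-only sketch mirrors the documented hook names and signatures on a hypothetical `AccumulateThenEvaluate` class, driven by a fake loop with `trainer` and `pl_module` set to `None`. The `outputs` format (`{"preds": ..., "targets": ...}`), the use of an `on_test_epoch_end` hook for evaluation, and the exact-match `accuracy` metric are all assumptions for illustration, not the neuralbench implementation.

```python
from typing import Any, Callable, Dict, List

Metric = Callable[[List[Any], List[Any]], float]


class AccumulateThenEvaluate:
    """Hypothetical accumulate-then-evaluate callback (stdlib-only sketch)."""

    def __init__(self, metrics: Dict[str, Metric]):
        self.metrics = metrics
        self.preds: List[Any] = []
        self.targets: List[Any] = []
        self.results: Dict[str, float] = {}

    def on_test_epoch_start(self, trainer, pl_module) -> None:
        # Clear the buffers so outputs from a previous run do not leak in.
        self.preds, self.targets = [], []

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None:
        # Only store this batch's outputs; no metric is computed here.
        self.preds.extend(outputs["preds"])
        self.targets.extend(outputs["targets"])

    def on_test_epoch_end(self, trainer, pl_module) -> None:
        # Assumed evaluation point: every metric sees the full test set at once.
        self.results = {name: fn(self.preds, self.targets) for name, fn in self.metrics.items()}


def accuracy(preds: List[Any], targets: List[Any]) -> float:
    """Placeholder metric: fraction of exact matches."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


cb = AccumulateThenEvaluate({"accuracy": accuracy})
fake_outputs = [
    {"preds": ["cat", "dog"], "targets": ["cat", "cow"]},
    {"preds": ["owl", "emu"], "targets": ["owl", "emu"]},
]
cb.on_test_epoch_start(trainer=None, pl_module=None)
for idx, out in enumerate(fake_outputs):
    cb.on_test_batch_end(None, None, out, batch=None, batch_idx=idx)
cb.on_test_epoch_end(None, None)
print(cb.results)  # {'accuracy': 0.75}
```

Keeping the batch hooks free of metric computation means the metrics are evaluated once, over all stored predictions, at a cost of holding every output in memory until the epoch ends.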