neuralbench.callbacks.TestFullRetrievalMetrics

class neuralbench.callbacks.TestFullRetrievalMetrics(event_type: Literal['Word', 'Image'] = 'Word', event_field: Literal['text', 'category'] = 'text', retrieval_set_sizes: tuple = (None, 250), save_outputs: bool = False, logger: Any = None, eval_val: bool = False)

Accumulate predictions over the entire test set before evaluating metrics.

Requires the pl.LightningModule to which this callback is attached to define a torch.nn.ModuleDict named retrieval_metrics containing the metrics to evaluate.
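A plausible reason for accumulating first is that a retrieval metric ranks each prediction against a pool of candidates; evaluated per batch, that pool would be limited to the events in a single batch. The sketch below is a minimal, hypothetical top-k retrieval accuracy in plain Python, not neuralbench's metric. The function name `top_k_retrieval_accuracy`, the use of cosine similarity, and the reading of a retrieval set size (None ranks against every candidate; an integer n ranks against the true candidate plus n - 1 random distractors) are all assumptions.

```python
import random
from typing import List, Optional, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(y * y for y in b) ** 0.5
    return dot / (norm_a * norm_b)


def top_k_retrieval_accuracy(
    preds: List[Sequence[float]],
    targets: List[int],
    candidates: List[Sequence[float]],
    k: int = 1,
    retrieval_set_size: Optional[int] = None,
    seed: int = 0,
) -> float:
    """Fraction of predictions whose true candidate ranks in the top k.

    Assumed semantics: retrieval_set_size=None ranks against all candidates;
    an integer n ranks against the true candidate plus n - 1 random distractors.
    """
    rng = random.Random(seed)
    hits = 0
    for pred, target in zip(preds, targets):
        pool = list(range(len(candidates)))
        if retrieval_set_size is not None and retrieval_set_size < len(pool):
            distractors = [i for i in pool if i != target]
            pool = [target] + rng.sample(distractors, retrieval_set_size - 1)
        # Rank the candidate pool by similarity to the prediction, best first.
        ranked = sorted(pool, key=lambda i: cosine(pred, candidates[i]), reverse=True)
        hits += int(target in ranked[:k])
    return hits / len(preds)


# Toy data: predictions accumulated over the whole set, three candidate events.
candidates = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
preds = [[0.9, 0.1], [0.1, 0.9], [-0.8, 0.2]]
print(top_k_retrieval_accuracy(preds, [0, 1, 2], candidates))
print(top_k_retrieval_accuracy(preds, [0, 1, 2], candidates, retrieval_set_size=2))
```

Smaller retrieval sets make the task easier, so reporting both the full-set and a fixed-size score (as the default `(None, 250)` might suggest) keeps results comparable across datasets with different numbers of unique events.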

on_test_batch_end(trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx=0)

Called when the test batch ends.

on_test_epoch_end(trainer, pl_module) → None

Called when the test epoch ends.

on_test_epoch_start(trainer: Trainer, pl_module: LightningModule)

Called when the test epoch begins.

on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx=0)

Called when the validation batch ends.

on_validation_epoch_end(trainer, pl_module) → None

Called when the validation epoch ends.

on_validation_epoch_start(trainer: Trainer, pl_module: LightningModule)

Called when the validation epoch begins.
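Taken together, the test and validation hooks suggest a reset, accumulate, evaluate lifecycle: the epoch-start hook clears buffers, each batch-end hook appends that batch's outputs, and the epoch-end hook runs every metric once on the full collection. The class below is a simplified, hypothetical sketch of that lifecycle in plain Python, not neuralbench's implementation. It drops the `(trainer, pl_module, ...)` arguments the real hooks receive, takes metrics as a plain dict standing in for `pl_module.retrieval_metrics`, assumes batch outputs arrive as `(preds, targets)` pairs, and assumes `eval_val=True` enables the same accumulation during validation.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Metric = Callable[[List, List], float]
BatchOutputs = Tuple[Sequence, Sequence]  # assumed shape: (preds, targets)


class FullSetAccumulator:
    """Hypothetical accumulate-then-evaluate hook logic (not neuralbench code)."""

    def __init__(self, metrics: Dict[str, Metric], eval_val: bool = False):
        self.metrics = metrics  # stand-in for pl_module.retrieval_metrics
        self.eval_val = eval_val  # assumed: gates the validation hooks
        self.preds: List = []
        self.targets: List = []
        self.results: Dict[str, float] = {}

    def _reset(self) -> None:
        self.preds, self.targets = [], []

    def _accumulate(self, outputs: BatchOutputs) -> None:
        batch_preds, batch_targets = outputs
        self.preds.extend(batch_preds)
        self.targets.extend(batch_targets)

    def _evaluate(self, prefix: str) -> None:
        # Each metric runs once, on everything gathered during the epoch.
        self.results = {
            f"{prefix}/{name}": metric(self.preds, self.targets)
            for name, metric in self.metrics.items()
        }

    # Test hooks: always active.
    def on_test_epoch_start(self) -> None:
        self._reset()

    def on_test_batch_end(self, outputs: BatchOutputs) -> None:
        self._accumulate(outputs)

    def on_test_epoch_end(self) -> None:
        self._evaluate("test")

    # Validation hooks: no-ops unless eval_val=True.
    def on_validation_epoch_start(self) -> None:
        if self.eval_val:
            self._reset()

    def on_validation_batch_end(self, outputs: BatchOutputs) -> None:
        if self.eval_val:
            self._accumulate(outputs)

    def on_validation_epoch_end(self) -> None:
        if self.eval_val:
            self._evaluate("val")


def accuracy(preds: List, targets: List) -> float:
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


cb = FullSetAccumulator({"acc": accuracy})
cb.on_test_epoch_start()
cb.on_test_batch_end(([1, 2], [1, 3]))
cb.on_test_batch_end(([4], [4]))
cb.on_test_epoch_end()
cb.on_validation_batch_end(([0], [1]))  # ignored: eval_val is False
print(cb.results)
```

Keeping the buffers on the callback rather than the module means the metric sees one flat list per epoch regardless of batch size.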

setup(trainer: Trainer, pl_module: LightningModule, stage: str)

Called when fit, validate, test, predict, or tune begins.