neuraltrain.metrics.metrics.ImageSimilarity

class neuraltrain.metrics.metrics.ImageSimilarity(model_name: Literal['inceptionv3', 'alexnet', 'clip', 'efficientnet', 'swav'] = 'inceptionv3', layer: str | int = 'avgpool', torchmetrics_kwargs: dict[str, Any] | None = None)

Image similarity metric based on feature extraction from a pretrained network.

Code adapted from:

  • https://github.com/ozcelikfu/brain-diffuser/blob/main/scripts/evaluate_reconstruction.py

  • https://github.com/ozcelikfu/brain-diffuser/blob/main/scripts/eval_extract_features.py

Parameters:
  • model_name ({"inceptionv3", "alexnet", "clip", "efficientnet", "swav"}) – Pretrained network used for feature extraction.

  • layer (str or int) – Layer of the network to extract features from. Valid values depend on model_name.

  • torchmetrics_kwargs (dict or None) – Extra keyword arguments forwarded to the torchmetrics.Metric constructor.
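The model_name and layer parameters together decide which activations are used as image features. The PyTorch sketch below shows one way to capture the output of a named or indexed layer with a forward hook. It is a sketch under stated assumptions, not this class's implementation: the helpers resolve_layer and extract_features are hypothetical, the convention that an int indexes top-level children is assumed, and resizing inputs to a common size is only one possible way to compare predictions (H, W) with targets (H', W'). A small network with random weights stands in for a pretrained backbone, so nothing is downloaded.

```python
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


def resolve_layer(model: nn.Module, layer: str | int) -> nn.Module:
    # Hypothetical convention: a str names any submodule, an int indexes top-level children.
    if isinstance(layer, int):
        return list(model.children())[layer]
    return dict(model.named_modules())[layer]


@torch.no_grad()
def extract_features(model: nn.Module, layer: str | int, images: torch.Tensor,
                     size: int = 64) -> torch.Tensor:
    """Return flattened activations of `layer` for a batch of (N, 3, H, W) images."""
    captured: dict[str, torch.Tensor] = {}

    def hook(_module, _inputs, output):
        captured["feats"] = output

    handle = resolve_layer(model, layer).register_forward_hook(hook)
    try:
        # Resize to one input size so (H, W) predictions and (H', W') targets are comparable.
        x = F.interpolate(images, size=(size, size), mode="bilinear", align_corners=False)
        model.eval()
        model(x)
    finally:
        handle.remove()  # leave the model without stray hooks
    return captured["feats"].flatten(start_dim=1)


# Toy stand-in for a pretrained backbone: random weights, so nothing is downloaded.
toy = nn.Sequential()
toy.add_module("conv", nn.Conv2d(3, 8, kernel_size=3, padding=1))
toy.add_module("relu", nn.ReLU())
toy.add_module("avgpool", nn.AdaptiveAvgPool2d(1))
toy.add_module("flatten", nn.Flatten())
toy.add_module("fc", nn.Linear(8, 10))

preds = torch.rand(4, 3, 32, 32)
trues = torch.rand(4, 3, 48, 40)
pred_feats = extract_features(toy, "avgpool", preds)
true_feats = extract_features(toy, 2, trues)  # index 2 is the "avgpool" child
print(pred_feats.shape, true_feats.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```

Using a hook leaves the backbone untouched and works for any submodule, whereas slicing the model would depend on each architecture's layout.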

compute() → Tensor

Compute the final metric value from the state accumulated by update().

When running with a distributed backend, state variables are automatically synchronized across processes before the value is computed.
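The brain-diffuser evaluation script referenced above scores several feature spaces with a pairwise "two-way identification" test: for each reconstruction, it counts how often its matching ground-truth image correlates with it more strongly than a distractor image does. The NumPy sketch below illustrates that idea on already-extracted feature matrices. Whether compute() returns exactly this quantity is not stated on this page, and two_way_identification is a hypothetical helper.

```python
import numpy as np


def two_way_identification(pred_feats: np.ndarray, true_feats: np.ndarray) -> float:
    """Fraction of (prediction, distractor) pairs won by the matching true; 0.5 is chance."""
    n = len(true_feats)
    # corrcoef stacks both inputs as rows; keep the block with rows = trues, cols = preds.
    r = np.corrcoef(true_feats, pred_feats)[:n, n:]
    congruent = np.diag(r)  # congruent[j] = corr(true_j, pred_j)
    # For each prediction j, count trues i whose corr with pred j is below the congruent one.
    wins = (r < congruent).sum(axis=0)
    # The diagonal can never win, so normalise by the n - 1 distractors.
    return float(wins.mean() / (n - 1))


rng = np.random.default_rng(0)
true_feats = rng.normal(size=(10, 64))
pred_feats = true_feats + 0.1 * rng.normal(size=(10, 64))  # near-perfect reconstructions
score = two_way_identification(pred_feats, true_feats)
print(score)
```

Because the score compares each reconstruction only against other images in the same set, it is insensitive to a global scale or offset in the features, which is why correlation rather than Euclidean distance is used here.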

update(preds: Tensor, trues: Tensor) → None

Update the internal list of ranks with a new batch of predictions and retrieval-set examples.

Parameters:
  • preds – Tensor of predictions, of shape (N, 3, H, W).

  • trues – Tensor of retrieval-set examples, of shape (N, 3, H', W').
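The update() docstring refers to an internal list of ranks. A minimal sketch of how such a list could be accumulated across batches is shown below, assuming features have already been extracted and that preds[i] should retrieve trues[i]. RankAccumulator is a hypothetical class, not the neuraltrain implementation; cosine similarity as the score and top-1 accuracy as the final reduction are assumptions, and candidates are limited to the current batch.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F


class RankAccumulator:
    """Hypothetical rank-based retrieval metric over already-extracted features."""

    def __init__(self) -> None:
        self.ranks: list[torch.Tensor] = []

    def update(self, pred_feats: torch.Tensor, true_feats: torch.Tensor) -> None:
        # Assumes pred_feats[i] should retrieve true_feats[i]; both are (N, D).
        sims = F.normalize(pred_feats, dim=1) @ F.normalize(true_feats, dim=1).T
        target = sims.diagonal().unsqueeze(1)  # similarity to the matching example
        # Rank = number of candidates strictly more similar than the match (0 is best).
        self.ranks.append((sims > target).sum(dim=1))

    def compute(self) -> torch.Tensor:
        # Top-1 retrieval accuracy over every prediction seen so far.
        return (torch.cat(self.ranks) == 0).float().mean()


torch.manual_seed(0)
metric = RankAccumulator()
trues = torch.randn(8, 16)
metric.update(trues + 0.01 * torch.randn(8, 16), trues)  # near-copies: each match ranks first
metric.update(torch.randn(8, 16), torch.randn(8, 16))    # unrelated batch: about chance
print(metric.compute())
```

Storing per-example ranks rather than a running accuracy keeps the choice of final reduction (top-1, top-k, mean rank) open until compute() is called.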