neuraltrain.metrics.metrics.GroupedMetric

class neuraltrain.metrics.metrics.GroupedMetric(metric_name: str, kwargs: dict[str, Any] | None = None)

A wrapper around a torchmetrics.Metric that computes a separate metric value for each group. IMPORTANT: this metric does not integrate well with LightningModule, because self.log() does not accept a dictionary of metric values.
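The per-group idea can be sketched in plain Python. This is a minimal illustration under stated assumptions, not neuraltrain's implementation: the real class wraps a torchmetrics.Metric and works on tensors, whereas here lists stand in for tensors, `BinaryAccuracy` and `PerGroupWrapper` are hypothetical names, and a fresh copy of the base metric is assumed to be created the first time a group appears.

```python
import copy
from collections.abc import Hashable, Iterable


class BinaryAccuracy:
    """Hypothetical base metric: fraction of predictions that equal the target."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: list[int], target: list[int]) -> None:
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.correct / self.total if self.total else float("nan")

    def reset(self) -> None:
        self.correct = 0
        self.total = 0


class PerGroupWrapper:
    """Hypothetical sketch: an independent copy of a base metric for every group."""

    def __init__(self, template: BinaryAccuracy) -> None:
        self.template = template
        self.metrics: dict[str, BinaryAccuracy] = {}

    def update(self, preds: list[int], target: list[int], groups: Iterable[Hashable]) -> None:
        group_ids = list(groups)
        if not len(preds) == len(target) == len(group_ids):
            raise ValueError("preds, target and groups must have the same length")
        # Collect this batch's samples per group, keyed by the group id as a string.
        batches: dict[str, tuple[list[int], list[int]]] = {}
        for p, t, g in zip(preds, target, group_ids):
            group_preds, group_target = batches.setdefault(str(g), ([], []))
            group_preds.append(p)
            group_target.append(t)
        # Route each group's samples to its own metric, cloning the template on first use.
        for key, (group_preds, group_target) in batches.items():
            if key not in self.metrics:
                self.metrics[key] = copy.deepcopy(self.template)
            self.metrics[key].update(group_preds, group_target)

    def compute(self) -> dict[str, float]:
        return {key: metric.compute() for key, metric in self.metrics.items()}

    def reset(self) -> None:
        # Drop every per-group metric; groups are re-created on the next update.
        self.metrics.clear()


acc = PerGroupWrapper(BinaryAccuracy())
acc.update([1, 1, 1, 1], [1, 1, 1, 0], groups=["a", "a", "b", "b"])
print(acc.compute())  # {'a': 1.0, 'b': 0.5}
```

Keeping one independent metric object per group means any base metric with update/compute/reset can be reused unchanged; the cost is one state copy per distinct group.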

To use this metric, compute, log, and reset it yourself in the on_validation_epoch_end and on_test_epoch_end hooks:

    metric_dict = {metric_name + "/" + k: v for k, v in grouped_metric.compute().items()}
    self.log_dict(metric_dict)
    grouped_metric.reset()

compute() → dict[str, float]

Compute the final metric value for each group and return the results as a dictionary with one entry per group.

This method automatically synchronizes state variables when running with a distributed backend.
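Synchronizing state variables, rather than already-computed values, matters when a group's samples are split unevenly across processes. The sketch below illustrates this with made-up per-rank (correct, total) counts and plain dictionaries in place of a real distributed reduction; `merge_states` is a hypothetical name, and this is not a description of how GroupedMetric synchronizes internally. Summing raw counts per group before dividing yields the accuracy over all samples, while averaging each rank's finished accuracy does not.

```python
def merge_states(
    per_rank: list[dict[str, tuple[int, int]]],
) -> dict[str, tuple[int, int]]:
    """Sum (correct, total) counts per group across ranks, in the spirit of a sum reduction."""
    merged: dict[str, tuple[int, int]] = {}
    for states in per_rank:
        for group, (correct, total) in states.items():
            c, t = merged.get(group, (0, 0))
            merged[group] = (c + correct, t + total)
    return merged


# Hypothetical states: group "a" has 1 sample on rank 0 and 3 samples on rank 1;
# group "b" only appears on rank 1.
rank0 = {"a": (1, 1)}
rank1 = {"a": (1, 3), "b": (2, 2)}

merged = merge_states([rank0, rank1])
synced = {g: c / t for g, (c, t) in merged.items()}
print(synced)  # {'a': 0.5, 'b': 1.0}

# Averaging each rank's finished accuracy weights ranks equally instead of samples.
naive_a = (1 / 1 + 1 / 3) / 2
print(round(naive_a, 3))  # 0.667
```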

reset() → None

Reset the metric state variables to their default values.

update(preds: Tensor, target: Tensor, groups: Tensor | None = None) → None

Update each group’s metric separately. groups: a tensor or list of group identifiers with the same shape as preds and target.
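Since groups may be given as a tensor or a list, an implementation has to normalize it before routing samples. A minimal sketch, assuming one group identifier per sample; `group_keys` is a hypothetical helper, and the `tolist()` duck-typing check is an assumption chosen so that a torch.Tensor is handled without importing PyTorch. Converting identifiers to strings is likewise an assumption, made here to match the str keys in the dict[str, float] returned by compute().

```python
from typing import Any


def group_keys(groups: Any, num_samples: int) -> list[str]:
    """Hypothetical helper: one string key per sample, from a tensor or a list."""
    # torch.Tensor (and numpy.ndarray) expose tolist(); other iterables go through list().
    values = groups.tolist() if hasattr(groups, "tolist") else list(groups)
    if len(values) != num_samples:
        raise ValueError(
            f"groups has {len(values)} entries, expected {num_samples} (one per sample)"
        )
    # Stringify so keys have the same type as the dict[str, float] from compute().
    return [str(v) for v in values]


print(group_keys([0, 1, 1], 3))  # ['0', '1', '1']
print(group_keys(["left", "right"], 2))  # ['left', 'right']
```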