neuraltrain.metrics.metrics.GroupedMetric
- class neuraltrain.metrics.metrics.GroupedMetric(metric_name: str, kwargs: dict[str, Any] | None = None)
A wrapper around a torchmetrics.Metric that computes the wrapped metric separately for each group. IMPORTANT: this metric does not work well with LightningModule, because self.log() does not accept a dictionary of metric values.
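To make the idea of a per-group metric concrete, here is a minimal, dependency-free sketch. It is not the neuraltrain implementation: the class name GroupedAccuracy, the update(preds, targets, groups) signature, and the choice of plain accuracy are all illustrative assumptions. It follows the same update/compute/reset life cycle as a torchmetrics.Metric, keeps separate counters for each group label, and returns a dictionary from compute(), which is why the result has to be logged with log_dict rather than log.

```python
from collections import defaultdict
from typing import Hashable, Sequence


class GroupedAccuracy:
    """Hypothetical stand-in for a grouped metric: accuracy tracked per group.

    Mirrors the update/compute/reset life cycle of a torchmetrics.Metric,
    but uses plain Python so the example runs without extra dependencies.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # One (correct, total) pair of counters per group label.
        self._correct = defaultdict(int)
        self._total = defaultdict(int)

    def update(
        self, preds: Sequence, targets: Sequence, groups: Sequence[Hashable]
    ) -> None:
        if not (len(preds) == len(targets) == len(groups)):
            raise ValueError("preds, targets and groups must have the same length")
        # Route every sample to the counters of its own group.
        for p, t, g in zip(preds, targets, groups):
            self._total[g] += 1
            self._correct[g] += int(p == t)

    def compute(self) -> dict:
        # One value per group seen since the last reset().
        return {g: self._correct[g] / self._total[g] for g in self._total}


metric = GroupedAccuracy()
metric.update(preds=[1, 0, 1, 1], targets=[1, 0, 1, 0], groups=["a", "a", "b", "b"])
print(metric.compute())  # {'a': 1.0, 'b': 0.5}
```

Keying the counters by group label means new groups can appear at any step without being declared up front; the trade-off is that a group with no samples since the last reset() simply does not appear in the output.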
- To use this metric, add the following to the on_validation_epoch_end and on_test_epoch_end methods of your LightningModule:

    metric_dict = {metric_name + "/" + k: v for k, v in grouped_metric.compute().items()}
    self.log_dict(metric_dict)
    grouped_metric.reset()