Module audiocraft.metrics.pesq
Classes
class PesqMetric(sample_rate: int)
    class PesqMetric(torchmetrics.Metric):
        """Metric for Perceptual Evaluation of Speech Quality.
        (https://doi.org/10.5281/zenodo.6549559)
        """

        sum_pesq: torch.Tensor
        total: torch.Tensor

        def __init__(self, sample_rate: int):
            super().__init__()
            self.sr = sample_rate
            self.add_state("sum_pesq", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds: torch.Tensor, targets: torch.Tensor):
            if self.sr != 16000:
                preds = julius.resample_frac(preds, self.sr, 16000)
                targets = julius.resample_frac(targets, self.sr, 16000)
            for ii in range(preds.size(0)):
                try:
                    self.sum_pesq += pesq.pesq(
                        16000,
                        targets[ii, 0].detach().cpu().numpy(),
                        preds[ii, 0].detach().cpu().numpy(),
                    )
                    self.total += 1
                except pesq.NoUtterancesError:
                    # This error can happen when the sample doesn't contain speech.
                    pass

        def compute(self) -> torch.Tensor:
            return (
                self.sum_pesq / self.total
                if (self.total != 0).item()
                else torch.tensor(0.0)
            )

Metric for Perceptual Evaluation of Speech Quality. (https://doi.org/10.5281/zenodo.6549559)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Ancestors
- torchmetrics.metric.Metric
- torch.nn.modules.module.Module
- abc.ABC
Class variables
var sum_pesq : torch.Tensor
var total : torch.Tensor
Methods
def compute(self) -> torch.Tensor
    def compute(self) -> torch.Tensor:
        return (
            self.sum_pesq / self.total
            if (self.total != 0).item()
            else torch.tensor(0.0)
        )

Override this method to compute the final metric value.
This method will automatically synchronize state variables when running in a distributed backend.
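Both states are registered with `dist_reduce_fx="sum"`, so synchronization adds the per-process sums and counts before the division takes place. The following is a minimal, library-free sketch of that arithmetic. `mean_or_zero` and `merge_states` are hypothetical helper names used only for illustration; they are not part of audiocraft or torchmetrics, and the per-process values are made up.

```python
from typing import List, Tuple


def mean_or_zero(score_sum: float, count: int) -> float:
    """Mirror the compute() expression: the average, or 0.0 when nothing was scored."""
    return score_sum / count if count != 0 else 0.0


def merge_states(states: List[Tuple[float, int]]) -> Tuple[float, int]:
    """Sum-reduce (sum_pesq, total) pairs, as a "sum" reduction would across processes."""
    total_sum = sum(score_sum for score_sum, _ in states)
    total_count = sum(count for _, count in states)
    return total_sum, total_count


# Hypothetical per-process states: (accumulated score, number of scored samples).
worker_states = [(7.0, 2), (3.5, 1)]
merged_sum, merged_count = merge_states(worker_states)

global_mean = mean_or_zero(merged_sum, merged_count)  # 10.5 / 3
empty_mean = mean_or_zero(0.0, 0)  # no sample was scored, so the fallback applies
```

Under this scheme the result is the mean over all scored samples, not the mean of per-process means, and a run in which every sample was skipped reports 0.0 rather than dividing by zero.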
def update(self, preds: torch.Tensor, targets: torch.Tensor)
    def update(self, preds: torch.Tensor, targets: torch.Tensor):
        if self.sr != 16000:
            preds = julius.resample_frac(preds, self.sr, 16000)
            targets = julius.resample_frac(targets, self.sr, 16000)
        for ii in range(preds.size(0)):
            try:
                self.sum_pesq += pesq.pesq(
                    16000,
                    targets[ii, 0].detach().cpu().numpy(),
                    preds[ii, 0].detach().cpu().numpy(),
                )
                self.total += 1
            except pesq.NoUtterancesError:
                # This error can happen when the sample doesn't contain speech.
                pass

Override this method to update the state variables of your metric class.
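The loop in `update` scores one batch item at a time and leaves both states unchanged for any item the scorer rejects. Below is a minimal sketch of that accumulate-and-skip pattern without the `pesq` or `torch` dependencies. `NoSpeechError`, `toy_score`, and `accumulate` are hypothetical stand-ins introduced for illustration, and `toy_score` is a placeholder, not a PESQ implementation.

```python
from typing import Callable, Sequence, Tuple


class NoSpeechError(Exception):
    """Hypothetical stand-in for pesq.NoUtterancesError."""


def toy_score(reference: Sequence[float], degraded: Sequence[float]) -> float:
    """Placeholder scorer (NOT PESQ): rejects all-zero references, else returns 1.0."""
    if not any(reference):
        raise NoSpeechError("reference contains no signal")
    return 1.0


def accumulate(
    preds: Sequence[Sequence[float]],
    targets: Sequence[Sequence[float]],
    scorer: Callable[[Sequence[float], Sequence[float]], float] = toy_score,
) -> Tuple[float, int]:
    """Sum scores over a batch, skipping items for which the scorer raises."""
    score_sum, count = 0.0, 0
    for pred, target in zip(preds, targets):
        try:
            # The reference signal comes first, matching the argument order in update().
            score_sum += scorer(target, pred)
            count += 1
        except NoSpeechError:
            # Skipped items contribute to neither the sum nor the count.
            pass
    return score_sum, count


# Two items: the second reference is silent, so it is skipped.
batch_targets = [[0.1, -0.2, 0.3], [0.0, 0.0, 0.0]]
batch_preds = [[0.1, -0.1, 0.2], [0.0, 0.1, 0.0]]
score_sum, count = accumulate(batch_preds, batch_targets)
```

In the actual method each item is indexed as `preds[ii, 0]` and `targets[ii, 0]`, so only the first channel of each batch entry is scored, and inputs are resampled to 16000 Hz first when `sample_rate` differs.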