Module audiocraft.metrics.pesq

Classes

class PesqMetric (sample_rate: int)
Expand source code
class PesqMetric(torchmetrics.Metric):
    """Metric for Perceptual Evaluation of Speech Quality (PESQ).
    (https://doi.org/10.5281/zenodo.6549559)

    Accumulates the sum of per-item PESQ scores and the number of scored items
    across ``update`` calls; ``compute`` returns their mean.

    Args:
        sample_rate (int): Sample rate of the audio passed to ``update``. Audio at
            any other rate is resampled to 16 kHz, the rate passed to ``pesq.pesq``.
    """

    sum_pesq: torch.Tensor
    total: torch.Tensor

    def __init__(self, sample_rate: int):
        super().__init__()
        self.sr = sample_rate

        # Both states are reduced with "sum" when synchronized across processes.
        self.add_state("sum_pesq", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, targets: torch.Tensor):
        """Score a batch with PESQ and accumulate the results.

        Only channel 0 of each item is scored. Items for which ``pesq.pesq`` raises
        ``pesq.NoUtterancesError`` are skipped and do not count towards the average.

        Args:
            preds: Degraded/estimated audio, indexed as ``preds[:, 0]`` — presumably
                shaped (batch, channels, time); TODO confirm against callers.
            targets: Reference audio, laid out the same way as ``preds``.

        Raises:
            ValueError: If ``preds`` and ``targets`` have different batch sizes.
        """
        if preds.size(0) != targets.size(0):
            raise ValueError(
                f"preds and targets must have the same batch size, "
                f"got {preds.size(0)} and {targets.size(0)}"
            )
        pesq_sr = 16000
        # Keep only the scored channel before resampling, so the resampler does
        # not process channels that would be discarded anyway.
        preds = preds[:, 0]
        targets = targets[:, 0]
        if self.sr != pesq_sr:
            preds = julius.resample_frac(preds, self.sr, pesq_sr)
            targets = julius.resample_frac(targets, self.sr, pesq_sr)
        # One host transfer for the whole batch instead of one per item.
        ref_batch = targets.detach().cpu().numpy()
        deg_batch = preds.detach().cpu().numpy()
        for ref, deg in zip(ref_batch, deg_batch):
            try:
                score = pesq.pesq(pesq_sr, ref, deg)
            except pesq.NoUtterancesError:
                # This error can happen when a sample doesn't contain speech;
                # such samples have no score and are left out of the average.
                continue
            self.sum_pesq += score
            self.total += 1

    def compute(self) -> torch.Tensor:
        """Return the mean PESQ score over all scored items.

        Returns zero when no item has been scored. That zero has the same dtype
        and device as the ``sum_pesq`` state, so both branches return a
        consistent tensor.
        """
        if self.total.item() == 0:
            return torch.zeros_like(self.sum_pesq)
        return self.sum_pesq / self.total

Metric for Perceptual Evaluation of Speech Quality (PESQ); see https://doi.org/10.5281/zenodo.6549559.

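Initializes internal Module state (shared by both nn.Module and ScriptModule) and the metric's accumulators. `sample_rate` is the sample rate of the audio passed to `update`; audio at any other rate is resampled to 16 kHz before scoring.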

Ancestors

  • torchmetrics.metric.Metric
  • torch.nn.modules.module.Module
  • abc.ABC

Class variables

var sum_pesq : torch.Tensor
var total : torch.Tensor

Methods

def compute(self) ‑> torch.Tensor
Expand source code
def compute(self) -> torch.Tensor:
    """Return the mean PESQ score over all scored items.

    Returns zero when no item has been scored. That zero has the same dtype
    and device as the ``sum_pesq`` state, so both branches return a consistent
    tensor (previously the zero was always a CPU float32 tensor).
    """
    if self.total.item() == 0:
        return torch.zeros_like(self.sum_pesq)
    return self.sum_pesq / self.total

Computes the final metric value: the mean PESQ score over all scored samples, or 0 if no sample has been scored.

This method automatically synchronizes state variables when running with a distributed backend.

def update(self, preds: torch.Tensor, targets: torch.Tensor)
Expand source code
def update(self, preds: torch.Tensor, targets: torch.Tensor):
    """Score a batch with PESQ and accumulate the results.

    Only channel 0 of each item is scored. Items for which ``pesq.pesq`` raises
    ``pesq.NoUtterancesError`` are skipped and do not count towards the average.

    Args:
        preds: Degraded/estimated audio, indexed as ``preds[:, 0]`` — presumably
            shaped (batch, channels, time); TODO confirm against callers.
        targets: Reference audio, laid out the same way as ``preds``.

    Raises:
        ValueError: If ``preds`` and ``targets`` have different batch sizes.
    """
    if preds.size(0) != targets.size(0):
        raise ValueError(
            f"preds and targets must have the same batch size, "
            f"got {preds.size(0)} and {targets.size(0)}"
        )
    pesq_sr = 16000
    # Keep only the scored channel before resampling, so the resampler does
    # not process channels that would be discarded anyway.
    preds = preds[:, 0]
    targets = targets[:, 0]
    if self.sr != pesq_sr:
        preds = julius.resample_frac(preds, self.sr, pesq_sr)
        targets = julius.resample_frac(targets, self.sr, pesq_sr)
    # One host transfer for the whole batch instead of one per item.
    ref_batch = targets.detach().cpu().numpy()
    deg_batch = preds.detach().cpu().numpy()
    for ref, deg in zip(ref_batch, deg_batch):
        try:
            score = pesq.pesq(pesq_sr, ref, deg)
        except pesq.NoUtterancesError:
            # This error can happen when a sample doesn't contain speech;
            # such samples have no score and are left out of the average.
            continue
        self.sum_pesq += score
        self.total += 1

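Updates the metric state with a batch: scores channel 0 of each item with PESQ at 16 kHz (resampling first when the input sample rate differs) and accumulates the score sum and item count. Items for which PESQ raises `NoUtterancesError` (e.g. samples without speech) are skipped.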