Module audiocraft.metrics.kld

Functions

def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-06) ‑> torch.Tensor
Expand source code
def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    """Compute the per-sample KL-Divergence between predicted and target label distributions.

    The pointwise KL terms are computed for every (sample, class) entry and then
    summed over the last (class) dimension, giving one value per sample pair.

    Args:
        pred_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on generated audio. Expected shape is [B, num_classes].
        target_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on target audio. Expected shape is [B, num_classes].
        epsilon (float): Constant added to ``pred_probs`` before taking the log,
            so that zero probabilities do not yield ``-inf``.
    Returns:
        kld (torch.Tensor): KLD loss between each generated sample and target pair,
            with the class dimension summed out.
    """
    # kl_div takes its first argument in log-space and the target as plain probabilities.
    log_pred = torch.log(pred_probs + epsilon)
    pointwise = torch.nn.functional.kl_div(log_pred, target_probs, reduction="none")
    return pointwise.sum(dim=-1)

Computes the elementwise KL-Divergence loss between probability distributions from generated samples and target samples.

Args

pred_probs : torch.Tensor
Probabilities for each label obtained from a classifier on generated audio. Expected shape is [B, num_classes].
target_probs : torch.Tensor
Probabilities for each label obtained from a classifier on target audio. Expected shape is [B, num_classes].
epsilon : float
Small constant added to pred_probs before taking the logarithm, for numerical stability.

Returns

kld (torch.Tensor): KLD loss between each generated sample and target pair.

Classes

class KLDivergenceMetric
Expand source code
class KLDivergenceMetric(torchmetrics.Metric):
    """Base implementation for KL Divergence metric.

    The KL divergence is measured between probability distributions
    of class predictions returned by a pre-trained audio classification model.
    When the KL-divergence is low, the generated audio is expected to
    have similar acoustic characteristics as the reference audio,
    according to the classifier.

    Subclasses must implement ``_get_label_distribution`` to provide the
    classifier probabilities; this base class only accumulates and averages.
    """
    def __init__(self):
        super().__init__()
        # Running sums of the per-sample divergences in both argument orders.
        self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        # NOTE(review): this state is registered but never updated or read in this class.
        self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        # Number of compared distribution pairs accumulated so far.
        self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Get model output given provided input tensor.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor, optional): Probabilities over labels, with classes
                on the last dimension, or None when no distribution could be
                extracted (``update`` then skips the batch).
        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("implement method to extract label distributions from the model.")

    def update(self, preds: torch.Tensor, targets: torch.Tensor,
               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Calculates running KL-Divergence loss between batches of audio
        preds (generated) and targets (ground-truth).

        The batch is silently skipped when either label distribution is None.

        Args:
            preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
            targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        """
        assert preds.shape == targets.shape
        assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
        preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
        targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
        if preds_probs is None or targets_probs is None:
            return
        assert preds_probs.shape == targets_probs.shape
        kld_scores = kl_divergence(preds_probs, targets_probs)
        kld_qp_scores = kl_divergence(targets_probs, preds_probs)
        # Validate both directions before touching any state, so a failed check
        # leaves the running sums untouched and neither sum can absorb a NaN.
        assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
        assert not torch.isnan(kld_qp_scores).any(), "kld_qp_scores contains NaN value(s)!"
        self.kld_pq_sum += torch.sum(kld_scores)
        self.kld_qp_sum += torch.sum(kld_qp_scores)
        self.weight += kld_scores.size(0)

    def compute(self) -> dict:
        """Computes KL-Divergence across all evaluated pred/target pairs.

        Returns:
            dict: ``'kld_pq'`` and ``'kld_qp'`` are the accumulated sums divided by the
                number of compared pairs, ``'kld_both'`` is their sum, and ``'kld'``
                holds the same value as ``'kld_pq'``.
        """
        weight: float = float(self.weight.item())  # type: ignore
        assert weight > 0, "Unable to compute with total number of comparisons <= 0"
        logger.info(f"Computing KL divergence on a total of {weight} samples")
        kld_pq = self.kld_pq_sum.item() / weight  # type: ignore
        kld_qp = self.kld_qp_sum.item() / weight  # type: ignore
        kld_both = kld_pq + kld_qp
        return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}

Base implementation for KL Divergence metric.

The KL divergence is measured between probability distributions of class predictions returned by a pre-trained audio classification model. When the KL-divergence is low, the generated audio is expected to have similar acoustic characteristics as the reference audio, according to the classifier.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Ancestors

  • torchmetrics.metric.Metric
  • torch.nn.modules.module.Module
  • abc.ABC

Subclasses

Class variables

var full_state_update : bool | None
var higher_is_better : bool | None
var is_differentiable : bool | None
var plot_legend_name : str | None
var plot_lower_bound : float | None
var plot_upper_bound : float | None

Methods

def compute(self) ‑> dict
Expand source code
def compute(self) -> dict:
    """Average the accumulated KL-Divergence sums over all evaluated pred/target pairs.

    Returns:
        dict: ``'kld_pq'`` and ``'kld_qp'`` are the running sums divided by the number
            of compared pairs, ``'kld_both'`` is their sum, and ``'kld'`` holds the
            same value as ``'kld_pq'``.
    """
    weight: float = float(self.weight.item())  # type: ignore
    assert weight > 0, "Unable to compute with total number of comparisons <= 0"
    logger.info(f"Computing KL divergence on a total of {weight} samples")
    mean_pq = self.kld_pq_sum.item() / weight  # type: ignore
    mean_qp = self.kld_qp_sum.item() / weight  # type: ignore
    return {
        'kld': mean_pq,
        'kld_pq': mean_pq,
        'kld_qp': mean_qp,
        'kld_both': mean_pq + mean_qp,
    }

Computes KL-Divergence across all evaluated pred/target pairs.

def update(self,
preds: torch.Tensor,
targets: torch.Tensor,
sizes: torch.Tensor,
sample_rates: torch.Tensor) ‑> None
Expand source code
def update(self, preds: torch.Tensor, targets: torch.Tensor,
           sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
    """Calculates running KL-Divergence loss between batches of audio
    preds (generated) and targets (ground-truth).

    The batch is silently skipped when either label distribution is None.

    Args:
        preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
        targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
        sizes (torch.Tensor): Actual audio sample length, of shape [B].
        sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
    """
    assert preds.shape == targets.shape
    assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
    preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
    targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
    if preds_probs is None or targets_probs is None:
        return
    assert preds_probs.shape == targets_probs.shape
    kld_scores = kl_divergence(preds_probs, targets_probs)
    kld_qp_scores = kl_divergence(targets_probs, preds_probs)
    # Validate both directions before touching any state, so a failed check
    # leaves the running sums untouched and neither sum can absorb a NaN.
    assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
    assert not torch.isnan(kld_qp_scores).any(), "kld_qp_scores contains NaN value(s)!"
    self.kld_pq_sum += torch.sum(kld_scores)
    self.kld_qp_sum += torch.sum(kld_qp_scores)
    self.weight += kld_scores.size(0)

Calculates running KL-Divergence loss between batches of audio preds (generated) and targets (ground-truth).

Args

preds : torch.Tensor
Audio samples to evaluate, of shape [B, C, T].
targets : torch.Tensor
Target samples to compare against, of shape [B, C, T].
sizes : torch.Tensor
Actual audio sample length, of shape [B].
sample_rates : torch.Tensor
Actual audio sample rate, of shape [B].
class PasstKLDivergenceMetric (pretrained_length: float | None = None)
Expand source code
class PasstKLDivergenceMetric(KLDivergenceMetric):
    """KL-Divergence metric based on pre-trained PASST classifier on AudioSet.

    From: PaSST: Efficient Training of Audio Transformers with Patchout
    Paper: https://arxiv.org/abs/2110.05069
    Implementation: https://github.com/kkoutini/PaSST

    Follow instructions from the github repo:
    ```
    pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
    ```

    Args:
        pretrained_length (float, optional): Audio duration used for the pretrained model.
            Values of 20 and 30 select the matching pretrained variants; any other
            value (including None) selects the base model with a 10 second limit.
    """
    def __init__(self, pretrained_length: tp.Optional[float] = None):
        super().__init__()
        self._initialize_model(pretrained_length)

    def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
        """Initialize underlying PaSST audio classifier."""
        model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
        self.min_input_frames = min_frames
        self.max_input_frames = max_frames
        self.model_sample_rate = sr
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def _load_base_model(self, pretrained_length: tp.Optional[float]):
        """Load pretrained model from PaSST.

        Args:
            pretrained_length (float, optional): Audio duration used for the pretrained model.
        Returns:
            tuple: ``(model, model_sample_rate, max_input_frames, min_input_frames)``,
                where the frame counts are the max/min durations expressed in
                samples at the model sample rate.
        Raises:
            ModuleNotFoundError: If the hear21passt package cannot be imported.
        """
        try:
            if pretrained_length == 30:
                from hear21passt.base30sec import get_basic_model  # type: ignore
                max_duration = 30
            elif pretrained_length == 20:
                from hear21passt.base20sec import get_basic_model  # type: ignore
                max_duration = 20
            else:
                from hear21passt.base import get_basic_model  # type: ignore
                # Original PASST was trained on AudioSet with 10s-long audio samples
                max_duration = 10
        except ModuleNotFoundError as err:
            # Build a single message string: passing the two parts as separate
            # arguments would make the exception render as a tuple.
            raise ModuleNotFoundError(
                "Please install hear21passt to compute KL divergence: "
                "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
            ) from err
        # Segments of this duration (in seconds) or shorter are dropped in _process_audio.
        min_duration = 0.15
        model_sample_rate = 32_000
        max_input_frames = int(max_duration * model_sample_rate)
        min_input_frames = int(min_duration * model_sample_rate)
        # Silence whatever the model factory prints to stdout.
        with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
            model = get_basic_model(mode='logits')
        return model, model_sample_rate, max_input_frames, min_input_frames

    def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
        """Process audio to feed to the pretrained model.

        Args:
            wav (torch.Tensor): A single audio item (one entry of the [B, C, T] batch).
            sample_rate (int): Sample rate of ``wav``.
            wav_len (int): Number of valid samples in ``wav``; anything beyond is cut off.
        Returns:
            list of torch.Tensor: Consecutive segments of at most ``max_input_frames``
                samples, each with an extra leading dimension. Segments of
                ``min_input_frames`` samples or fewer are discarded, so the list may be empty.
        """
        wav = wav.unsqueeze(0)
        wav = wav[..., :wav_len]
        # Resample to the model sample rate and downmix to a single channel.
        wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
        wav = wav.squeeze(0)
        # we don't pad but return a list of audio segments as this otherwise affects the KLD computation
        segments = torch.split(wav, self.max_input_frames, dim=-1)
        # ignoring too small segments that are breaking the model inference
        return [s[None] for s in segments if s.size(-1) > self.min_input_frames]

    def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
        """Run the pretrained model and get the predictions.

        Args:
            wav (torch.Tensor): Preprocessed audio with exactly 3 dimensions;
                dimension 1 is averaged out before the model is called.
        Returns:
            torch.Tensor: Softmax over the last dimension of the model logits.
        """
        assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
        wav = wav.mean(dim=1)
        # PaSST is printing a lot of garbage that we are not interested in
        with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
            with torch.no_grad(), _patch_passt_stft():
                logits = self.model(wav.to(self.device))
                probs = torch.softmax(logits, dim=-1)
                return probs

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Get model output given provided input tensor.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor, optional): Probabilities over labels, one row per
                valid audio segment across the whole batch (so the first dimension
                can differ from B), or None when no segment was long enough.
        """
        all_probs: tp.List[torch.Tensor] = []
        for i, wav in enumerate(x):
            sample_rate = int(sample_rates[i].item())
            wav_len = int(sizes[i].item())
            wav_segments = self._process_audio(wav, sample_rate, wav_len)
            for segment in wav_segments:
                # Average over the leading dimension of the model output to get one row per segment.
                probs = self._get_model_preds(segment).mean(dim=0)
                all_probs.append(probs)
        if len(all_probs) > 0:
            return torch.stack(all_probs, dim=0)
        else:
            return None

KL-Divergence metric based on pre-trained PASST classifier on AudioSet.

From: "PaSST: Efficient Training of Audio Transformers with Patchout". Paper: https://arxiv.org/abs/2110.05069 — Implementation: https://github.com/kkoutini/PaSST

Follow instructions from the github repo:

pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'

Args

pretrained_length : float, optional
Audio duration used for the pretrained model.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Ancestors

Class variables

var full_state_update : bool | None
var higher_is_better : bool | None
var is_differentiable : bool | None
var plot_legend_name : str | None
var plot_lower_bound : float | None
var plot_upper_bound : float | None

Inherited members