Module audiocraft.metrics.kld
Functions
def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-06) ‑> torch.Tensor-
Expand source code
def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    """Per-pair KL-Divergence loss between generated-sample and target-sample label distributions.

    Args:
        pred_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on generated audio. Expected shape is [B, num_classes].
        target_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on target audio. Expected shape is [B, num_classes].
        epsilon (float): Constant added to `pred_probs` before the logarithm is taken.
    Returns:
        kld (torch.Tensor): KLD loss between each generated sample and target pair,
            i.e. the pointwise terms summed over the last dimension.
    """
    # `kl_div` takes its first argument in log-space; with the default (positive) epsilon
    # a zero predicted probability does not turn into log(0).
    log_pred = torch.log(pred_probs + epsilon)
    pointwise = torch.nn.functional.kl_div(log_pred, target_probs, reduction="none")
    return torch.sum(pointwise, dim=-1)
Args
pred_probs (torch.Tensor): Probabilities for each label obtained from a classifier on generated audio. Expected shape is [B, num_classes].
target_probs (torch.Tensor): Probabilities for each label obtained from a classifier on target audio. Expected shape is [B, num_classes].
epsilon (float): Small constant added to pred_probs before taking the logarithm, so that a zero probability does not yield log(0).
Returns
kld (torch.Tensor): KLD loss between each generated sample and target pair.
Classes
class KLDivergenceMetric-
Expand source code
class KLDivergenceMetric(torchmetrics.Metric):
    """Base implementation for KL Divergence metric.

    The KL divergence is measured between probability distributions
    of class predictions returned by a pre-trained audio classification model.
    When the KL-divergence is low, the generated audio is expected to
    have similar acoustic characteristics as the reference audio,
    according to the classifier.

    Subclasses must override `_get_label_distribution`; the base version raises
    `NotImplementedError`.
    """
    def __init__(self):
        super().__init__()
        # Running sum of kl_divergence(preds_probs, targets_probs) over all pairs seen.
        self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        # Running sum of kl_divergence(targets_probs, preds_probs) over all pairs seen.
        self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        # NOTE(review): registered but never read or written by this class; kept so that
        # any external code relying on the attribute keeps working.
        self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        # Number of pred/target pairs accumulated so far.
        self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Get model output given provided input tensor.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("implement method to extract label distributions from the model.")

    def update(self, preds: torch.Tensor, targets: torch.Tensor,
               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Calculates running KL-Divergence loss between batches of audio
        preds (generated) and target (ground-truth).

        If either label distribution comes back as None, the batch is skipped
        and no state is modified.

        Args:
            preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
            targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Raises:
            AssertionError: On shape mismatch, empty input, or NaN in either
                direction of the KL scores.
        """
        assert preds.shape == targets.shape
        assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
        preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
        targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
        if preds_probs is None or targets_probs is None:
            return
        assert preds_probs.shape == targets_probs.shape
        kld_scores = kl_divergence(preds_probs, targets_probs)
        kld_qp_scores = kl_divergence(targets_probs, preds_probs)
        # Validate both directions before touching any state, so that a bad batch
        # can neither poison a running sum nor leave the sums half-updated.
        assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
        assert not torch.isnan(kld_qp_scores).any(), "kld_qp_scores contains NaN value(s)!"
        self.kld_pq_sum += torch.sum(kld_scores)
        self.kld_qp_sum += torch.sum(kld_qp_scores)
        self.weight += torch.tensor(kld_scores.size(0))

    def compute(self) -> dict:
        """Computes KL-Divergence across all evaluated pred/target pairs.

        Returns:
            dict: Mean scores under the keys 'kld' (same value as 'kld_pq'),
                'kld_pq', 'kld_qp' and 'kld_both' (sum of the two means).
        Raises:
            AssertionError: If no pair has been accumulated yet.
        """
        weight: float = float(self.weight.item())  # type: ignore
        assert weight > 0, "Unable to compute with total number of comparisons <= 0"
        logger.info(f"Computing KL divergence on a total of {weight} samples")
        kld_pq = self.kld_pq_sum.item() / weight  # type: ignore
        kld_qp = self.kld_qp_sum.item() / weight  # type: ignore
        kld_both = kld_pq + kld_qp
        return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
The KL divergence is measured between probability distributions of class predictions returned by a pre-trained audio classification model. When the KL-divergence is low, the generated audio is expected to have similar acoustic characteristics as the reference audio, according to the classifier.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Ancestors
- torchmetrics.metric.Metric
- torch.nn.modules.module.Module
- abc.ABC
Subclasses
Class variables
var full_state_update : bool | None
var higher_is_better : bool | None
var is_differentiable : bool | None
var plot_legend_name : str | None
var plot_lower_bound : float | None
var plot_upper_bound : float | None
Methods
def compute(self) ‑> dict-
Expand source code
def compute(self) -> dict:
    """Average the accumulated KL-Divergence sums over every pred/target pair seen.

    Returns:
        dict: Mean scores under the keys 'kld' (same value as 'kld_pq'),
            'kld_pq', 'kld_qp' and 'kld_both' (sum of the two means).
    """
    weight: float = float(self.weight.item())  # type: ignore
    assert weight > 0, "Unable to compute with total number of comparisons <= 0"
    logger.info(f"Computing KL divergence on a total of {weight} samples")
    mean_pq = self.kld_pq_sum.item() / weight  # type: ignore
    mean_qp = self.kld_qp_sum.item() / weight  # type: ignore
    return {
        'kld': mean_pq,
        'kld_pq': mean_pq,
        'kld_qp': mean_qp,
        'kld_both': mean_pq + mean_qp,
    }
def update(self,
preds: torch.Tensor,
targets: torch.Tensor,
sizes: torch.Tensor,
sample_rates: torch.Tensor) ‑> None-
Expand source code
def update(self, preds: torch.Tensor, targets: torch.Tensor,
           sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
    """Calculates running KL-Divergence loss between batches of audio
    preds (generated) and target (ground-truth).

    If either label distribution comes back as None, the batch is skipped
    and no state is modified.

    Args:
        preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
        targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
        sizes (torch.Tensor): Actual audio sample length, of shape [B].
        sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
    Raises:
        AssertionError: On shape mismatch, empty input, or NaN in either
            direction of the KL scores.
    """
    assert preds.shape == targets.shape
    assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
    preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
    targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
    if preds_probs is None or targets_probs is None:
        return
    assert preds_probs.shape == targets_probs.shape
    kld_scores = kl_divergence(preds_probs, targets_probs)
    kld_qp_scores = kl_divergence(targets_probs, preds_probs)
    # Validate both directions before touching any state, so that a bad batch
    # can neither poison a running sum nor leave the sums half-updated.
    assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
    assert not torch.isnan(kld_qp_scores).any(), "kld_qp_scores contains NaN value(s)!"
    self.kld_pq_sum += torch.sum(kld_scores)
    self.kld_qp_sum += torch.sum(kld_qp_scores)
    self.weight += torch.tensor(kld_scores.size(0))
Args
preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
sizes (torch.Tensor): Actual audio sample length, of shape [B].
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
class PasstKLDivergenceMetric (pretrained_length: float | None = None)-
Expand source code
class PasstKLDivergenceMetric(KLDivergenceMetric):
    """KL-Divergence metric based on pre-trained PASST classifier on AudioSet.

    From: PaSST: Efficient Training of Audio Transformers with Patchout
    Paper: https://arxiv.org/abs/2110.05069
    Implementation: https://github.com/kkoutini/PaSST

    Follow instructions from the github repo:
    ```
    pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
    ```

    Args:
        pretrained_length (float, optional): Audio duration used for the pretrained model.
    """
    def __init__(self, pretrained_length: tp.Optional[float] = None):
        super().__init__()
        self._initialize_model(pretrained_length)

    def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
        """Initialize underlying PaSST audio classifier."""
        model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
        self.min_input_frames = min_frames
        self.max_input_frames = max_frames
        self.model_sample_rate = sr
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def _load_base_model(self, pretrained_length: tp.Optional[float]):
        """Load pretrained model from PaSST.

        Args:
            pretrained_length (float, optional): 30 or 20 selects the matching
                `hear21passt` variant; any other value falls back to the base model.
        Returns:
            tuple: (model, model_sample_rate, max_input_frames, min_input_frames).
        Raises:
            ModuleNotFoundError: If `hear21passt` is not installed.
        """
        try:
            if pretrained_length == 30:
                from hear21passt.base30sec import get_basic_model  # type: ignore
                max_duration = 30
            elif pretrained_length == 20:
                from hear21passt.base20sec import get_basic_model  # type: ignore
                max_duration = 20
            else:
                from hear21passt.base import get_basic_model  # type: ignore
                # Original PASST was trained on AudioSet with 10s-long audio samples
                max_duration = 10
        except ModuleNotFoundError as err:
            # One concatenated message (no stray comma, which would turn the exception
            # args into a tuple), chained to the original import failure.
            raise ModuleNotFoundError(
                "Please install hear21passt to compute KL divergence: "
                "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
            ) from err
        # Segments at or below this duration (in seconds) are dropped in `_process_audio`.
        min_duration = 0.15
        model_sample_rate = 32_000
        max_input_frames = int(max_duration * model_sample_rate)
        min_input_frames = int(min_duration * model_sample_rate)
        # Silence whatever the model constructor prints to stdout.
        with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
            model = get_basic_model(mode='logits')
        return model, model_sample_rate, max_input_frames, min_input_frames

    def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
        """Process audio to feed to the pretrained model.

        The waveform is trimmed to `wav_len`, converted to the model sample rate
        and a single channel, then split along time into chunks of at most
        `self.max_input_frames` frames.

        Args:
            wav (torch.Tensor): One item of the batch (presumably [C, T] — confirm against caller).
            sample_rate (int): Sample rate of `wav`.
            wav_len (int): Number of valid frames in `wav`.
        Returns:
            list of torch.Tensor: Retained segments, each with an extra leading dimension.
        """
        wav = wav.unsqueeze(0)
        wav = wav[..., :wav_len]
        wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
        wav = wav.squeeze(0)
        # we don't pad but return a list of audio segments as this otherwise affects the KLD computation
        segments = torch.split(wav, self.max_input_frames, dim=-1)
        # ignoring too small segments that are breaking the model inference
        return [s[None] for s in segments if s.size(-1) > self.min_input_frames]

    def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
        """Run the pretrained model and get the predictions.

        Args:
            wav (torch.Tensor): Preprocessed audio with exactly 3 dimensions;
                dimension 1 is averaged away before the model is called.
        Returns:
            torch.Tensor: Softmax of the model logits over the last dimension.
        """
        assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
        wav = wav.mean(dim=1)
        # PaSST is printing a lot of garbage that we are not interested in
        with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
            with torch.no_grad(), _patch_passt_stft():
                logits = self.model(wav.to(self.device))
                probs = torch.softmax(logits, dim=-1)
                return probs

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Get model output given provided input tensor.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor, optional): Probabilities over labels, stacked with one
                row per retained audio segment (so the first dimension equals B only
                when every item yields exactly one segment), or None when no segment
                was retained.
        """
        all_probs: tp.List[torch.Tensor] = []
        for i, wav in enumerate(x):
            sample_rate = int(sample_rates[i].item())
            wav_len = int(sizes[i].item())
            wav_segments = self._process_audio(wav, sample_rate, wav_len)
            for segment in wav_segments:
                # Each segment carries a leading dimension of size 1; averaging over it
                # drops that dimension from the model output.
                probs = self._get_model_preds(segment).mean(dim=0)
                all_probs.append(probs)
        if len(all_probs) > 0:
            return torch.stack(all_probs, dim=0)
        else:
            return None
From: PaSST: Efficient Training of Audio Transformers with Patchout Paper: https://arxiv.org/abs/2110.05069 Implementation: https://github.com/kkoutini/PaSST
Follow instructions from the github repo:
pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
Args
pretrained_length (float, optional): Audio duration used for the pretrained model.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Ancestors
- KLDivergenceMetric
- torchmetrics.metric.Metric
- torch.nn.modules.module.Module
- abc.ABC
Class variables
var full_state_update : bool | None
var higher_is_better : bool | None
var is_differentiable : bool | None
var plot_legend_name : str | None
var plot_lower_bound : float | None
var plot_upper_bound : float | None
Inherited members