Module audiocraft.metrics.clap_consistency
Classes
class CLAPTextConsistencyMetric (model_path: str | pathlib.Path,
model_arch: str = 'HTSAT-tiny',
enable_fusion: bool = False)-
Expand source code
class CLAPTextConsistencyMetric(TextConsistencyMetric):
    """Text consistency metric built on Contrastive Language-Audio Pretraining (CLAP).

    CLAP is a joint audio-text embedding model, so a pretrained CLAP model can
    quantify how similar an audio clip and a text description are. The metric is
    similar to the MuLan Cycle Consistency from MusicLM
    (https://arxiv.org/pdf/2301.11325.pdf) and to the CLAP score used in
    Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).

    The CLAP embeddings of the text descriptions and of the audio generated from
    them are computed, and the metric (MCC) is defined as the average cosine
    similarity between these embeddings.

    Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP

    Args:
        model_path: Reference to the CLAP checkpoint; it is resolved through
            ``AudioCraftEnvironment.resolve_reference_path`` before loading.
        model_arch: Passed to ``laion_clap.CLAP_Module`` as ``amodel``.
        enable_fusion: Passed to ``laion_clap.CLAP_Module`` as ``enable_fusion``.

    Raises:
        ImportError: If ``laion_clap`` is ``None`` when the metric is constructed.
    """

    def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny',
                 enable_fusion: bool = False):
        super().__init__()
        clap_missing = laion_clap is None
        if clap_missing:
            raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
        # Running totals: sum of cosine similarities and number of pairs seen.
        # Both are registered with a "sum" reduction.
        for state_name in ("cosine_sum", "weight"):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")
        self._initialize_model(model_path, model_arch, enable_fusion)

    def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
        """Create the tokenizer and the CLAP module, load the checkpoint and switch to eval mode."""
        resolved_path = AudioCraftEnvironment.resolve_reference_path(model_path)
        self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
        clap_module = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
        self.model = clap_module
        # Sample rate the audio is resampled to before being embedded.
        self.model_sample_rate = 48_000
        load_clap_state_dict(self.model, resolved_path)
        self.model.eval()

    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
        """Tokenize ``texts`` with the RoBERTa tokenizer, padded/truncated to 77 tokens."""
        # we use the default params from CLAP module here as well
        tokenizer_options = dict(padding="max_length", truncation=True, max_length=77, return_tensors="pt")
        return self.tokenize(texts, **tokenizer_options)

    def update(self, audio: torch.Tensor, text: tp.List[str],
               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Accumulate the cosine similarity of each audio/text pair in the batch.

        Args:
            audio: Batch of audio, one item per entry of ``text``.
            text: Text descriptions paired with the audio items.
            sizes: Part of the metric interface; not used by this implementation.
            sample_rates: Sample rate of each item; all entries must be equal.
        """
        assert audio.size(0) == len(text), "Number of audio and text samples should match"
        first_rate = sample_rates[0].item()
        assert torch.all(sample_rates == first_rate), "All items in batch should have the same sample rate"
        sample_rate = int(first_rate)
        # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
        mono_audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
        mono_audio = mono_audio.mean(dim=1)
        audio_emb = self.model.get_audio_embedding_from_data(mono_audio, use_tensor=True)
        text_emb = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
        # one cosine similarity per (audio, text) pair
        similarities = torch.nn.functional.cosine_similarity(audio_emb, text_emb, dim=1, eps=1e-8)
        self.cosine_sum += similarities.sum(dim=0)
        self.weight += torch.tensor(similarities.size(0))

    def compute(self):
        """Return the average cosine similarity across all accumulated audio/text pairs."""
        total_pairs = self.weight
        assert total_pairs.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
        mean_similarity = self.cosine_sum / total_pairs
        return mean_similarity.item()  # type: ignore
Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf) or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as well as the generated audio based on them, and define the MCC metric as the average cosine similarity between these embeddings.
Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Ancestors
- TextConsistencyMetric
- torchmetrics.metric.Metric
- torch.nn.modules.module.Module
- abc.ABC
Class variables
var full_state_update : bool | None
var higher_is_better : bool | None
var is_differentiable : bool | None
var plot_legend_name : str | None
var plot_lower_bound : float | None
var plot_upper_bound : float | None
Methods
def compute(self)
-
Expand source code
def compute(self):
    """Return the average cosine similarity across all accumulated audio/text pairs.

    Divides the accumulated ``cosine_sum`` by the accumulated ``weight`` and
    returns the result as a Python number. The assertion rejects the case where
    no pair has been accumulated yet.
    """
    total_pairs = self.weight
    assert total_pairs.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
    mean_similarity = self.cosine_sum / total_pairs
    return mean_similarity.item()  # type: ignore
Computes the average cosine similarity across all audio/text pairs.
def update(self,
audio: torch.Tensor,
text: List[str],
sizes: torch.Tensor,
sample_rates: torch.Tensor) ‑> None-
Expand source code
def update(self, audio: torch.Tensor, text: tp.List[str],
           sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
    """Accumulate the cosine similarity of each audio/text pair in the batch.

    Args:
        audio: Batch of audio, one item per entry of ``text``.
        text: Text descriptions paired with the audio items.
        sizes: Part of the metric interface; not used by this implementation.
        sample_rates: Sample rate of each item; all entries must be equal.
    """
    assert audio.size(0) == len(text), "Number of audio and text samples should match"
    first_rate = sample_rates[0].item()
    assert torch.all(sample_rates == first_rate), "All items in batch should have the same sample rate"
    sample_rate = int(first_rate)
    # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
    mono_audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
    mono_audio = mono_audio.mean(dim=1)
    audio_emb = self.model.get_audio_embedding_from_data(mono_audio, use_tensor=True)
    text_emb = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
    # one cosine similarity per (audio, text) pair
    similarities = torch.nn.functional.cosine_similarity(audio_emb, text_emb, dim=1, eps=1e-8)
    self.cosine_sum += similarities.sum(dim=0)
    self.weight += torch.tensor(similarities.size(0))
Compute cosine similarity between audio and text pairs and accumulate scores over the dataset.
class TextConsistencyMetric (**kwargs: Any)
-
Expand source code
class TextConsistencyMetric(torchmetrics.Metric):
    """Base class for metrics measuring consistency between audio and text pairs.

    Both ``update`` and ``compute`` raise ``NotImplementedError`` here; concrete
    metrics are expected to override them.
    """

    def update(self, audio: torch.Tensor, text: tp.List[str],
               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Update the metric state from a batch of audio and text pairs (not implemented here)."""
        reason = "implement how to update the metric from the audio and text pairs."
        raise NotImplementedError(reason)

    def compute(self):
        """Compute the final metric score (not implemented here)."""
        reason = "implement how to compute the final metric score."
        raise NotImplementedError(reason)
Text consistency metric measuring consistency between audio and text pairs.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Ancestors
- torchmetrics.metric.Metric
- torch.nn.modules.module.Module
- abc.ABC
Subclasses
Class variables
var full_state_update : bool | None
var higher_is_better : bool | None
var is_differentiable : bool | None
var plot_legend_name : str | None
var plot_lower_bound : float | None
var plot_upper_bound : float | None
Methods
def compute(self)
-
Expand source code
def compute(self):
    """Compute the final metric score (not implemented here).

    Raises:
        NotImplementedError: Always; the computation must be provided by an override.
    """
    reason = "implement how to compute the final metric score."
    raise NotImplementedError(reason)
Override this method to compute the final metric value.
This method will automatically synchronize state variables when running in distributed backend.
def update(self,
audio: torch.Tensor,
text: List[str],
sizes: torch.Tensor,
sample_rates: torch.Tensor) ‑> None-
Expand source code
def update(self, audio: torch.Tensor, text: tp.List[str],
           sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
    """Update the metric state from a batch of audio and text pairs (not implemented here).

    Raises:
        NotImplementedError: Always; the update logic must be provided by an override.
    """
    reason = "implement how to update the metric from the audio and text pairs."
    raise NotImplementedError(reason)
Override this method to update the state variables of your metric class.