Module audiocraft.losses.loudnessloss

Functions

def basic_loudness(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor
def basic_loudness(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """This is a simpler loudness function that is more stable.
    Args:
        waveform(torch.Tensor): audio waveform of dimension `(..., channels, time)`
        sample_rate (int): sampling rate of the waveform
    Returns:
        loudness loss as a scalar
    """

    if waveform.size(-2) > 5:
        raise ValueError("Only up to 5 channels are supported.")
    eps = torch.finfo(torch.float32).eps
    gate_duration = 0.4
    overlap = 0.75
    gate_samples = int(round(gate_duration * sample_rate))
    step = int(round(gate_samples * (1 - overlap)))

    # Apply K-weighting
    waveform = treble_biquad(waveform, sample_rate, 4.0, 1500.0, 1 / math.sqrt(2))
    waveform = highpass_biquad(waveform, sample_rate, 38.0, 0.5)

    # Compute the energy for each block
    energy = torch.square(waveform).unfold(-1, gate_samples, step)
    energy = torch.mean(energy, dim=-1)

    # Compute channel-weighted summation
    g = torch.tensor([1.0, 1.0, 1.0, 1.41, 1.41], dtype=waveform.dtype, device=waveform.device)
    g = g[: energy.size(-2)]

    energy_weighted = torch.sum(g.unsqueeze(-1) * energy, dim=-2)
    # loudness with epsilon for stability. Not as much precision in the very low loudness sections
    loudness = -0.691 + 10 * torch.log10(energy_weighted + eps)
    return loudness

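A simpler loudness function that is more stable. It applies K-weighting, computes the mean energy of 400 ms blocks with 75 % overlap, sums the channels with the ITU-R BS.1770 channel weights (1.0 for the first three channels, 1.41 for the fourth and fifth) and converts each block to dB as -0.691 + 10·log10(weighted energy + eps). Unlike a full BS.1770 measurement, the blocks are neither gated nor averaged. The small eps keeps the logarithm finite, at the cost of precision in very quiet sections.

Args

waveform : torch.Tensor
audio waveform of dimension (…, channels, time), with at most 5 channels
sample_rate : int
sampling rate of the waveform

Returns

the loudness of each block, as a tensor of dimension (…, n_blocks); it is not reduced to a scalar

Raises

ValueError
if the waveform has more than 5 channels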
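The block layout and the final dB conversion can be checked by hand. The sketch below is plain Python, not the library code: it assumes each channel has already been K-weighted and reduced to its mean square, and the helper names `block_layout` and `block_loudness` are hypothetical. It reproduces the block length and step used by basic_loudness (0.4 s blocks, 75 % overlap), the number of complete windows `Tensor.unfold` extracts, and the channel-weighted dB formula for one block.

```python
import math

# BS.1770-style channel weights, as used by basic_loudness (up to 5 channels).
CHANNEL_WEIGHTS = [1.0, 1.0, 1.0, 1.41, 1.41]
EPS = 1.1920928955078125e-07  # torch.finfo(torch.float32).eps


def block_layout(num_samples: int, sample_rate: int,
                 gate_duration: float = 0.4, overlap: float = 0.75):
    """Return (block_length, step, n_blocks) for a signal of num_samples samples.

    Mirrors the arithmetic of basic_loudness; n_blocks is the number of complete
    windows of length block_length taken every `step` samples.
    """
    block_length = int(round(gate_duration * sample_rate))
    step = int(round(block_length * (1 - overlap)))
    if num_samples < block_length:
        raise ValueError("signal is shorter than one loudness block")
    n_blocks = (num_samples - block_length) // step + 1
    return block_length, step, n_blocks


def block_loudness(mean_square_per_channel):
    """Loudness (dB) of one block from the mean square of each K-weighted channel."""
    if len(mean_square_per_channel) > len(CHANNEL_WEIGHTS):
        raise ValueError("Only up to 5 channels are supported.")
    weighted = sum(g * z for g, z in zip(CHANNEL_WEIGHTS, mean_square_per_channel))
    return -0.691 + 10 * math.log10(weighted + EPS)


# 0.5 s of mono audio at 16 kHz: 6400-sample blocks every 1600 samples.
print(block_layout(8000, 16000))        # (6400, 1600, 2)
# A block with mean square 0.5 on a single channel.
print(round(block_loudness([0.5]), 2))  # -3.7
```

With the default 0.5 s segments of TFLoudnessRatio at 16 kHz, each segment therefore yields two overlapping loudness blocks.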

Classes

class FLoudnessRatio (sample_rate: int = 16000,
segment: float | None = 20,
overlap: float = 0.5,
epsilon: float = 1.1920928955078125e-07,
n_bands: int = 0)
class FLoudnessRatio(nn.Module):
    """FSNR loss.

    Input should be [B, C, T], output is scalar.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
            entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
        epsilon (float): Epsilon value for numerical stability.
        n_bands (int): number of mel scale bands that we include
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        segment: tp.Optional[float] = 20,
        overlap: float = 0.5,
        epsilon: float = torch.finfo(torch.float32).eps,
        n_bands: int = 0,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.epsilon = epsilon
        if n_bands == 0:
            self.filter = None
        else:
            self.filter = julius.SplitBands(sample_rate=sample_rate, n_bands=n_bands)
        self.loudness = torchaudio.transforms.Loudness(sample_rate)

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        B, C, T = ref_sig.shape
        assert ref_sig.shape == out_sig.shape
        assert self.filter is not None
        bands_ref = self.filter(ref_sig)
        bands_out = self.filter(out_sig)
        l_noise = self.loudness(bands_ref - bands_out)
        l_ref = self.loudness(bands_ref)
        l_ratio = (l_noise - l_ref).view(-1, B)
        loss = torch.nn.functional.softmax(l_ratio, dim=0) * l_ratio
        return loss.sum()

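FSNR loss: a loudness ratio between the residual and the reference, computed per frequency band.

Both signals are split into n_bands mel-spaced bands with julius.SplitBands. In each band, torchaudio.transforms.Loudness measures the loudness of the residual ref_sig - out_sig (the watermark) and of the reference; the differences l_noise - l_ref, in dB, are weighted with a softmax over the bands, so that the bands where the residual is loudest relative to the reference dominate, and then summed.

Input should be [B, C, T] (with C at most 5, the limit of the loudness transform); output is a scalar.

Args

sample_rate : int
Sample rate.
segment : float or None
Chunk length in seconds (None for the entire audio). Stored as an attribute but not used by forward, which always evaluates the entire audio.
overlap : float
Overlap between chunks, i.e. 0.5 = 50 % overlap. Stored but not used by forward.
epsilon : float
Epsilon value for numerical stability. Stored but not used by forward.
n_bands : int
Number of mel-scale bands to split the signals into. Must be positive: with the default of 0 no band filter is created and forward fails its assertion.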
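The aggregation at the end of forward can be sketched in plain Python. This assumes the per-band loudness ratios are already computed (in the library, by julius.SplitBands and torchaudio.transforms.Loudness) and, as in forward, arranged with one row per band and one column per batch item; `band_weighted_loss` is a hypothetical name and the numbers are made-up inputs. Each column is softmax-weighted over the bands, and all weighted values are summed.

```python
import math
from typing import List


def softmax(values: List[float]) -> List[float]:
    """Numerically stable softmax of a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def band_weighted_loss(l_ratio: List[List[float]]) -> float:
    """Softmax-weighted sum of loudness ratios.

    l_ratio[band][item] holds l_noise - l_ref in dB for one band and one batch
    item. The softmax runs over the bands (dim 0), separately for each item.
    """
    n_bands, batch = len(l_ratio), len(l_ratio[0])
    loss = 0.0
    for item in range(batch):
        column = [l_ratio[band][item] for band in range(n_bands)]
        weights = softmax(column)
        loss += sum(w * r for w, r in zip(weights, column))
    return loss


# Two bands, one batch item: the band where the residual is relatively louder
# (-10 dB rather than -40 dB) receives almost all of the weight.
print(round(band_weighted_loss([[-10.0], [-40.0]]), 3))  # -10.0
```

The softmax acts like a soft maximum: minimizing the loss mostly pushes down the band where the residual is most audible, instead of treating all bands equally as a plain mean would.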

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Ancestors

  • torch.nn.modules.module.Module

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class TFLoudnessRatio (sample_rate: int = 16000,
segment: float = 0.5,
overlap: float = 0.5,
n_bands: int = 0,
clip_min: float = -100,
temperature: float = 1.0)
class TFLoudnessRatio(nn.Module):
    """TF-loudness ratio loss.

    Input should be [B, C, T], output is scalar.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
            entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
        n_bands (int): number of bands to separate
        temperature (float): temperature of the softmax step
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        segment: float = 0.5,
        overlap: float = 0.5,
        n_bands: int = 0,
        clip_min: float = -100,
        temperature: float = 1.0,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.clip_min = clip_min
        self.temperature = temperature
        if n_bands == 0:
            self.filter = None
        else:
            self.n_bands = n_bands
            self.filter = julius.SplitBands(sample_rate=sample_rate, n_bands=n_bands)

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        B, C, T = ref_sig.shape

        assert ref_sig.shape == out_sig.shape
        assert C == 1
        assert self.filter is not None

        bands_ref = self.filter(ref_sig).view(B * self.n_bands, 1, -1)
        bands_out = self.filter(out_sig).view(B * self.n_bands, 1, -1)
        frame = int(self.segment * self.sample_rate)
        stride = int(frame * (1 - self.overlap))
        gt = _unfold(bands_ref, frame, stride).squeeze(1).contiguous().view(-1, 1, frame)
        est = _unfold(bands_out, frame, stride).squeeze(1).contiguous().view(-1, 1, frame)
        l_noise = basic_loudness(est - gt, sample_rate=self.sample_rate)  # watermark
        l_ref = basic_loudness(gt, sample_rate=self.sample_rate)  # ground truth
        l_ratio = (l_noise - l_ref).view(-1, B)
        loss = torch.nn.functional.softmax(l_ratio / self.temperature, dim=0) * l_ratio
        return loss.mean()

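TF-loudness ratio loss: a loudness ratio between the residual and the reference, computed per frequency band and per time segment.

Both signals are split into n_bands frequency bands with julius.SplitBands, and each band is cut into segments of segment seconds with the given overlap. basic_loudness then measures the residual out_sig - ref_sig (the watermark) and the reference in every segment. The differences l_noise - l_ref, in dB, are weighted with a softmax scaled by temperature and averaged into the output.

Input should be [B, 1, T] (forward asserts a single channel); output is a scalar.

Args

sample_rate : int
Sample rate.
segment : float
Segment length in seconds. None is not supported, since the frame length is computed as int(segment * sample_rate); the segment should also be at least 0.4 s long so that it contains one basic_loudness block.
overlap : float
Overlap between segments, i.e. 0.5 = 50 % overlap.
n_bands : int
Number of frequency bands to split the signals into. Must be positive: with the default of 0 no band filter is created and forward fails its assertion.
clip_min : float
Stored as an attribute but not used by forward.
temperature : float
Temperature of the softmax step. Lower values concentrate the weight on the largest loudness ratios; higher values spread it more evenly.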
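Two parts of TFLoudnessRatio.forward are easy to check in isolation: the segment length and stride in samples, and the temperature-scaled softmax that weights the loudness ratios before the mean. The plain-Python sketch below reproduces only that arithmetic; `frame_and_stride` and `tempered_weights` are hypothetical names, and the loudness ratios are made-up inputs rather than library output.

```python
import math
from typing import List, Tuple


def frame_and_stride(segment: float, overlap: float, sample_rate: int) -> Tuple[int, int]:
    """Segment length and hop in samples, computed as in TFLoudnessRatio.forward."""
    frame = int(segment * sample_rate)
    stride = int(frame * (1 - overlap))
    return frame, stride


def tempered_weights(l_ratio: List[float], temperature: float) -> List[float]:
    """softmax(l_ratio / temperature), computed stably."""
    scaled = [r / temperature for r in l_ratio]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Defaults: 0.5 s segments with 50 % overlap at 16 kHz.
print(frame_and_stride(0.5, 0.5, 16000))  # (8000, 4000)

# Loudness ratios (dB) for three segments; the loudest residual is at -12 dB.
ratios = [-30.0, -12.0, -25.0]
for t in (10.0, 1.0, 0.1):
    # Lower temperatures push more of the weight onto the -12 dB segment.
    print(t, [round(w, 3) for w in tempered_weights(ratios, t)])
```

At a high temperature the weighting approaches a plain average of the ratios; at a low temperature it approaches selecting the single worst segment and band.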

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Ancestors

  • torch.nn.modules.module.Module

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class TLoudnessRatio (sample_rate: int = 16000, segment: float = 0.5, overlap: float = 0.5)
class TLoudnessRatio(nn.Module):
    """TSNR loss.

    Input should be [B, C, T], output is scalar.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
            entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        segment: float = 0.5,
        overlap: float = 0.5,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.loudness = torchaudio.transforms.Loudness(sample_rate)

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        B, C, T = ref_sig.shape
        assert ref_sig.shape == out_sig.shape
        assert C == 1

        frame = int(self.segment * self.sample_rate)
        stride = int(frame * (1 - self.overlap))
        gt = _unfold(ref_sig, frame, stride).view(-1, 1, frame)
        est = _unfold(out_sig, frame, stride).view(-1, 1, frame)
        l_noise = self.loudness(gt - est)  # watermark
        l_ref = self.loudness(gt)  # ground truth
        l_ratio = (l_noise - l_ref).view(-1, B)
        loss = torch.nn.functional.softmax(l_ratio, dim=0) * l_ratio
        return loss.sum()

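TSNR loss: a loudness ratio between the residual and the reference, computed per time segment.

Both signals are cut into segments of segment seconds with the given overlap. In every segment, torchaudio.transforms.Loudness measures the residual ref_sig - out_sig (the watermark) and the reference; the differences l_noise - l_ref, in dB, are weighted with a softmax and summed into the output.

Input should be [B, 1, T] (forward asserts a single channel); output is a scalar.

Args

sample_rate : int
Sample rate.
segment : float
Segment length in seconds. None is not supported, since the frame length is computed as int(segment * sample_rate); each segment should be long enough to contain one 400 ms loudness block.
overlap : float
Overlap between segments, i.e. 0.5 = 50 % overlap.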
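What the loss measures can be seen with a residual that is a scaled copy of the reference. The sketch below is plain Python, not the library code: it replaces the K-weighted, gated loudness of torchaudio.transforms.Loudness with a plain mean-square level, -0.691 + 10·log10(mean square), which is a simplifying assumption that drops the filtering and gating. Under that simplification, if out_sig = (1 - a)·ref_sig, the residual ref_sig - out_sig equals a·ref_sig, so the ratio l_noise - l_ref is 20·log10(a) dB (the -0.691 offset cancels); a = 0.01 gives -40 dB. `level_db` and `loudness_ratio_db` are hypothetical names.

```python
import math
from typing import List


def level_db(samples: List[float]) -> float:
    """Simplified level: -0.691 + 10*log10(mean square), no K-weighting or gating."""
    mean_square = sum(x * x for x in samples) / len(samples)
    return -0.691 + 10 * math.log10(mean_square)


def loudness_ratio_db(out_sig: List[float], ref_sig: List[float]) -> float:
    """l_noise - l_ref for one segment, with the residual taken as ref - out."""
    residual = [r - o for r, o in zip(ref_sig, out_sig)]
    return level_db(residual) - level_db(ref_sig)


sample_rate = 16000
# One 0.5 s segment of a 440 Hz tone at half amplitude.
ref = [0.5 * math.sin(2 * math.pi * 440 * n / sample_rate) for n in range(8000)]
a = 0.01  # the residual is 1 % of the reference amplitude
out = [(1 - a) * x for x in ref]
print(round(loudness_ratio_db(out, ref), 2))  # -40.0
```

A more negative ratio means a quieter residual relative to the reference, so minimizing the loss drives the watermark toward inaudibility in the segments where it is currently most prominent.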

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Ancestors

  • torch.nn.modules.module.Module

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.