Module audiocraft.modules.chroma

Classes

class ChromaExtractor (sample_rate: int,
n_chroma: int = 12,
radix2_exp: int = 12,
nfft: int | None = None,
winlen: int | None = None,
winhop: int | None = None,
argmax: bool = False,
norm: float = inf)
Expand source code
class ChromaExtractor(nn.Module):
    """Chroma extraction and quantization.

    Computes a power spectrogram of the input waveform, projects it onto chroma
    filterbanks built with ``filters.chroma`` and normalizes every frame.

    Args:
        sample_rate (int): Sample rate for the chroma extraction.
        n_chroma (int): Number of chroma bins for the chroma extraction.
        radix2_exp (int): Base-2 exponent of the default STFT window size
            (e.g. 12 -> 2^12 = 4096 samples). Only used when ``winlen`` is not given.
        nfft (int, optional): FFT size. Defaults to the window length.
        winlen (int, optional): Window length. Defaults to ``2 ** radix2_exp``.
        winhop (int, optional): Window hop size. Defaults to ``winlen // 4``.
        argmax (bool, optional): If True, quantize every frame to a one-hot vector
            at its largest chroma bin. Defaults to False.
        norm (float, optional): Order of the norm used to normalize each chroma
            frame (passed as ``p`` to ``normalize``). Defaults to inf.
    """
    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
                 norm: float = torch.inf):
        super().__init__()
        # Falsy values (None or 0) fall back to the defaults below.
        self.winlen = winlen or 2 ** radix2_exp
        self.nfft = nfft or self.winlen
        self.winhop = winhop or (self.winlen // 4)
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax
        # Non-persistent: the filterbank is rebuilt from the constructor args and
        # therefore not stored in state_dict.
        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
                                                                       n_chroma=self.n_chroma)), persistent=False)
        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
                                                      hop_length=self.winhop, power=2, center=True,
                                                      pad=0, normalized=True)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Extract (optionally one-hot quantized) chroma features from ``wav``.

        Args:
            wav (torch.Tensor): Waveform with time along the last dimension.
                Presumably shaped [B, 1, T] (dim 1 of the spectrogram is squeezed
                and the result must be 3-D) -- TODO confirm against callers.

        Returns:
            torch.Tensor: Chroma of shape [B, frames, C], where C is the first
            dimension of ``fbanks`` (the chroma bins). Each frame is normalized
            with ``p=self.norm``; if ``self.argmax`` is set, each frame is instead
            a one-hot vector at its largest bin.
        """
        T = wav.shape[-1]
        # In case we are getting a wav that was dropped out (nullified)
        # from the conditioner, make sure wav length is no less than nfft.
        # Zero-pad symmetrically; an odd remainder goes to the right.
        if T < self.nfft:
            pad = self.nfft - T
            wav = F.pad(wav, (pad // 2, pad - pad // 2), 'constant', 0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"

        spec = self.spec(wav).squeeze(1)
        # Project frequency bins onto chroma bins: [C, F] x [..., F, t] -> [..., C, t].
        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
        norm_chroma = F.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')

        if self.argmax:
            # Build the one-hot result out of place instead of overwriting the
            # normalized tensor in place.
            idx = norm_chroma.argmax(-1, keepdim=True)
            norm_chroma = torch.zeros_like(norm_chroma).scatter_(dim=-1, index=idx, value=1)

        return norm_chroma

Chroma extraction and quantization.

Args

sample_rate : int
Sample rate for the chroma extraction.
n_chroma : int
Number of chroma bins for the chroma extraction.
radix2_exp : int
Base-2 exponent of the default STFT window size (e.g. 12 -> 2^12 = 4096 samples). Only used when winlen is not given.
nfft : int, optional
FFT size. Defaults to the window length.
winlen : int, optional
Window length. Defaults to 2 ** radix2_exp.
winhop : int, optional
Window hop size. Defaults to winlen // 4.
argmax : bool, optional
Whether to quantize every frame to a one-hot vector at its largest chroma bin (argmax). Defaults to False.
norm : float, optional
Order of the norm used to normalize each chroma frame. Defaults to inf.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Ancestors

  • torch.nn.modules.module.Module

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def forward(self, wav: torch.Tensor) ‑> torch.Tensor
Expand source code
def forward(self, wav: torch.Tensor) -> torch.Tensor:
    """Extract (optionally one-hot quantized) chroma features from ``wav``.

    Args:
        wav (torch.Tensor): Waveform with time along the last dimension.
            Presumably shaped [B, 1, T] (dim 1 of the spectrogram is squeezed
            and the result must be 3-D) -- TODO confirm against callers.

    Returns:
        torch.Tensor: Chroma of shape [B, frames, C], where C is the first
        dimension of ``fbanks`` (the chroma bins). Each frame is normalized
        with ``p=self.norm``; if ``self.argmax`` is set, each frame is instead
        a one-hot vector at its largest bin.
    """
    T = wav.shape[-1]
    # In case we are getting a wav that was dropped out (nullified)
    # from the conditioner, make sure wav length is no less than nfft.
    # Zero-pad symmetrically; an odd remainder goes to the right.
    if T < self.nfft:
        pad = self.nfft - T
        wav = F.pad(wav, (pad // 2, pad - pad // 2), 'constant', 0)
        assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"

    spec = self.spec(wav).squeeze(1)
    # Project frequency bins onto chroma bins: [C, F] x [..., F, t] -> [..., C, t].
    raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
    norm_chroma = F.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
    norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')

    if self.argmax:
        # Build the one-hot result out of place instead of overwriting the
        # normalized tensor in place.
        idx = norm_chroma.argmax(-1, keepdim=True)
        norm_chroma = torch.zeros_like(norm_chroma).scatter_(dim=-1, index=idx, value=1)

    return norm_chroma

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.