Module audiocraft.modules.chroma
Classes
class ChromaExtractor (sample_rate: int,
n_chroma: int = 12,
radix2_exp: int = 12,
nfft: int | None = None,
winlen: int | None = None,
winhop: int | None = None,
argmax: bool = False,
norm: float = inf)
Expand source code
class ChromaExtractor(nn.Module):
    """Chroma extraction and quantization.

    Args:
        sample_rate (int): Sample rate for the chroma extraction.
        n_chroma (int): Number of chroma bins for the chroma extraction.
        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
        nfft (int, optional): Number of FFT. Defaults to the window length.
        winlen (int, optional): Window length. Defaults to ``2 ** radix2_exp``.
        winhop (int, optional): Window hop size. Defaults to ``winlen // 4``.
        argmax (bool, optional): Whether to use argmax, i.e. one-hot quantize each
            frame to its strongest chroma bin. Defaults to False.
        norm (float, optional): Norm for chroma normalization. Defaults to inf.
    """
    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12,
                 nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None,
                 winhop: tp.Optional[int] = None, argmax: bool = False,
                 norm: float = torch.inf):
        super().__init__()
        self.winlen = winlen or 2 ** radix2_exp
        self.nfft = nfft or self.winlen
        self.winhop = winhop or (self.winlen // 4)
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax
        # Chroma filterbank; forward() contracts its second axis against the
        # spectrogram's frequency axis ('cf,...ft'). Presumably shaped
        # (n_chroma, nfft // 2 + 1) -- TODO confirm against filters.chroma.
        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
                                                                        n_chroma=self.n_chroma)), persistent=False)
        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
                                                      hop_length=self.winhop, power=2, center=True,
                                                      pad=0, normalized=True)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Compute normalized (optionally one-hot quantized) chroma features.

        Args:
            wav (torch.Tensor): Input waveform with time on the last axis.
                Presumably (B, 1, T) -- TODO confirm with callers; dim 1 is
                squeezed out of the spectrogram.
        Returns:
            torch.Tensor: Chroma features with the chroma axis last, e.g.
            (B, frames, n_chroma) for a 3D spectrogram. Each frame is divided
            by its ``self.norm``-norm. If ``self.argmax`` is set, each frame is
            instead a one-hot vector at its largest bin.
        """
        T = wav.shape[-1]
        # In case we are getting a wav that was dropped out (nullified)
        # from the conditioner, make sure wav length is no less than nfft.
        if T < self.nfft:
            pad = self.nfft - T
            left = pad // 2
            right = pad - left  # odd padding puts the extra sample on the right
            wav = F.pad(wav, (left, right), 'constant', 0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
        spec = self.spec(wav).squeeze(1)
        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
        norm_chroma = F.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
        # Move the chroma axis last: (..., c, t) -> (..., t, c). Identical to
        # rearrange('b d t -> b t d') for 3D input, and also accepts the extra
        # leading dims that the '...' einsum above already supports.
        norm_chroma = norm_chroma.transpose(-1, -2)
        if self.argmax:
            idx = norm_chroma.argmax(-1, keepdim=True)
            # Build the one-hot result in a fresh tensor instead of zeroing the
            # normalized output in place through a view.
            norm_chroma = torch.zeros_like(norm_chroma).scatter_(dim=-1, index=idx, value=1)
        return norm_chroma
Chroma extraction and quantization.
Args
sample_rate : int
- Sample rate for the chroma extraction.
n_chroma : int
- Number of chroma bins for the chroma extraction.
radix2_exp : int
- Base-2 exponent of the default STFT window length (e.g. 12 -> 2^12 = 4096).
nfft : int, optional
- FFT size. Defaults to the window length.
winlen : int, optional
- STFT window length. Defaults to 2^radix2_exp.
winhop : int, optional
- Hop size between STFT windows. Defaults to winlen // 4.
argmax : bool, optional
- If True, quantize each frame to a one-hot vector at its strongest chroma bin. Defaults to False.
norm : float, optional
- Order of the norm used to normalize each frame's chroma vector. Defaults to inf.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Ancestors
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Methods
def forward(self, wav: torch.Tensor) -> torch.Tensor
Expand source code
def forward(self, wav: torch.Tensor) -> torch.Tensor:
    """Compute normalized (optionally one-hot quantized) chroma features.

    Args:
        wav (torch.Tensor): Input waveform with time on the last axis.
            Presumably (B, 1, T) -- TODO confirm with callers; dim 1 is
            squeezed out of the spectrogram.
    Returns:
        torch.Tensor: Chroma features with the chroma axis last, e.g.
        (B, frames, n_chroma) for a 3D spectrogram. Each frame is divided
        by its ``self.norm``-norm. If ``self.argmax`` is set, each frame is
        instead a one-hot vector at its largest bin.
    """
    T = wav.shape[-1]
    # In case we are getting a wav that was dropped out (nullified)
    # from the conditioner, make sure wav length is no less than nfft.
    if T < self.nfft:
        pad = self.nfft - T
        left = pad // 2
        right = pad - left  # odd padding puts the extra sample on the right
        wav = F.pad(wav, (left, right), 'constant', 0)
        assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
    spec = self.spec(wav).squeeze(1)
    raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
    norm_chroma = F.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
    # Move the chroma axis last: (..., c, t) -> (..., t, c). Identical to
    # rearrange('b d t -> b t d') for 3D input, and also accepts the extra
    # leading dims that the '...' einsum above already supports.
    norm_chroma = norm_chroma.transpose(-1, -2)
    if self.argmax:
        idx = norm_chroma.argmax(-1, keepdim=True)
        # Build the one-hot result in a fresh tensor instead of zeroing the
        # normalized output in place through a view.
        norm_chroma = torch.zeros_like(norm_chroma).scatter_(dim=-1, index=idx, value=1)
    return norm_chroma
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.