Module audiocraft.models.multibanddiffusion

Multi-Band Diffusion models as described in "From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" (https://arxiv.org/abs/2308.02560).
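
A minimal usage sketch (assumes audiocraft is installed and the pretrained checkpoints can be downloaded; the random waveform is only a stand-in for real audio):

import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

# Load the pretrained 24 kHz variant (downloads checkpoints on first use).
mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)

# One second of placeholder mono audio at the codec sample rate.
wav = torch.randn(1, 1, mbd.sample_rate).to(mbd.device)

# Compress to discrete tokens, then decode back with diffusion.
out = mbd.regenerate(wav, sample_rate=mbd.sample_rate)
print(out.shape)  # matches wav.shape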

Classes

class DiffusionProcess (model: DiffusionUnet,
noise_schedule: NoiseSchedule)
class DiffusionProcess:
    """Sampling for a diffusion Model.

    Args:
        model (DiffusionUnet): Diffusion U-Net model.
        noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
    """
    def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
        self.model = model
        self.schedule = noise_schedule

    def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Perform one diffusion process to generate one of the bands.

        Args:
            condition (torch.Tensor): The embeddings from the compression model.
            initial_noise (torch.Tensor): The initial noise to start the process.
        """
        return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
                                                 condition=condition)

Sampling for a diffusion model.

Args

model : DiffusionUnet
Diffusion U-Net model.
noise_schedule : NoiseSchedule
Noise schedule for diffusion process.

Methods

def generate(self,
condition: torch.Tensor,
initial_noise: torch.Tensor,
step_list: List[int] | None = None)
def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
             step_list: tp.Optional[tp.List[int]] = None):
    """Perform one diffusion process to generate one of the bands.

    Args:
        condition (torch.Tensor): The embeddings from the compression model.
        initial_noise (torch.Tensor): The initial noise to start the process.
    """
    return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
                                             condition=condition)

Perform one diffusion process to generate one of the bands.

Args

condition : torch.Tensor
The embeddings from the compression model.
initial_noise : torch.Tensor
The initial noise to start the process.
step_list : list of int, optional
Steps of the Markov chain to use; defaults to 50 linearly spaced steps.
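
A DiffusionProcess is normally obtained from a loaded MultiBandDiffusion rather than built by hand. A sketch of driving a single band's process directly (shapes assume the mono 24 kHz codec):

import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)
dp = mbd.DPs[0]  # diffusion process for the first frequency band

# Conditioning comes from the compression model; placeholder audio here.
wav = torch.randn(1, 1, mbd.sample_rate).to(mbd.device)
condition = mbd.get_condition(wav, sample_rate=mbd.sample_rate)

# The generated band has the waveform's shape, so seed with matching noise.
initial_noise = torch.randn_like(wav)
band = dp.generate(condition=condition, initial_noise=initial_noise)
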
class MultiBandDiffusion (DPs: List[DiffusionProcess],
codec_model: CompressionModel)
class MultiBandDiffusion:
    """Sample from multiple diffusion models.

    Args:
        DPs (list of DiffusionProcess): Diffusion processes.
        codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
    """
    def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
        self.DPs = DPs
        self.codec_model = codec_model
        self.device = next(self.codec_model.parameters()).device

    @property
    def sample_rate(self) -> int:
        return self.codec_model.sample_rate

    @staticmethod
    def get_mbd_musicgen(device=None):
        """Load our diffusion models trained for MusicGen."""
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        path = 'facebook/multiband-diffusion'
        filename = 'mbd_musicgen_32khz.th'
        name = 'facebook/musicgen-small'
        codec_model = load_compression_model(name, device=device)
        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
        DPs = []
        for i in range(len(models)):
            schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
            DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

    @staticmethod
    def get_mbd_24khz(bw: float = 3.0,
                      device: tp.Optional[tp.Union[torch.device, str]] = None,
                      n_q: tp.Optional[int] = None):
        """Get the pretrained Models for MultibandDiffusion.

        Args:
            bw (float): Bandwidth of the compression model.
            device (torch.device or str, optional): Device on which the models are loaded.
            n_q (int, optional): Number of quantizers to use within the compression model.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
        if n_q is not None:
            assert n_q in [2, 4, 8]
            assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
                f"bandwidth and number of codebooks mismatch: to use n_q = {n_q}, bw should be {n_q * (1.5 / 2)}"
        n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
        codec_model = CompressionSolver.model_from_checkpoint(
            '//pretrained/facebook/encodec_24khz', device=device)
        codec_model.set_num_codebooks(n_q)
        codec_model = codec_model.to(device)
        path = 'facebook/multiband-diffusion'
        filename = f'mbd_comp_{n_q}.pt'
        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
        DPs = []
        for i in range(len(models)):
            schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
            DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Get the conditioning (i.e. latent representations of the compression model) from a waveform.
        Args:
            wav (torch.Tensor): The audio that we want to extract the conditioning from.
            sample_rate (int): Sample rate of the audio."""
        if sample_rate != self.sample_rate:
            wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.get_emb(codes)
        return emb

    @torch.no_grad()
    def get_emb(self, codes: torch.Tensor):
        """Get latent representation from the discrete codes.
        Args:
            codes (torch.Tensor): Discrete tokens."""
        emb = self.codec_model.decode_latent(codes)
        return emb

    def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Generate waveform audio from the latent embeddings of the compression model.
        Args:
            emb (torch.Tensor): Conditioning embeddings
            size (None, torch.Size): Size of the output
                if None this is computed from the typical upsampling of the model.
            step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step.
        """
        if size is None:
            upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
            size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
        assert size[0] == emb.size(0)
        out = torch.zeros(size).to(self.device)
        for DP in self.DPs:
            out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
        return out

    def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
        """Match the eq to the encodec output by matching the standard deviation of some frequency bands.
        Args:
            wav (torch.Tensor): Audio to equalize.
            ref (torch.Tensor): Reference audio from which we match the spectrogram.
            n_bands (int): Number of bands of the eq.
            strictness (float): How strict the matching. 0 is no matching, 1 is exact matching.
        """
        split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
        bands = split(wav)
        bands_ref = split(ref)
        out = torch.zeros_like(ref)
        for i in range(n_bands):
            out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
        return out

    def regenerate(self, wav: torch.Tensor, sample_rate: int):
        """Regenerate a waveform through compression and diffusion regeneration.
        Args:
            wav (torch.Tensor): Original 'ground truth' audio.
            sample_rate (int): Sample rate of the input (and output) wav.
        """
        if sample_rate != self.codec_model.sample_rate:
            wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
        emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
        size = wav.size()
        out = self.generate(emb, size=size)
        if sample_rate != self.codec_model.sample_rate:
            out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
        return out

    def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
        """Generate Waveform audio with diffusion from the discrete codes.
        Args:
            tokens (torch.Tensor): Discrete codes.
            n_bands (int): Bands for the eq matching.
        """
        wav_encodec = self.codec_model.decode(tokens)
        condition = self.get_emb(tokens)
        wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
        return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)

Sample from multiple diffusion models.

Args

DPs : list of DiffusionProcess
Diffusion processes.
codec_model : CompressionModel
Underlying compression model used to obtain discrete tokens.

Static methods

def get_mbd_24khz(bw: float = 3.0,
device: str | torch.device | None = None,
n_q: int | None = None)
@staticmethod
def get_mbd_24khz(bw: float = 3.0,
                  device: tp.Optional[tp.Union[torch.device, str]] = None,
                  n_q: tp.Optional[int] = None):
    """Get the pretrained Models for MultibandDiffusion.

    Args:
        bw (float): Bandwidth of the compression model.
        device (torch.device or str, optional): Device on which the models are loaded.
        n_q (int, optional): Number of quantizers to use within the compression model.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
    if n_q is not None:
        assert n_q in [2, 4, 8]
        assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
            f"bandwidth and number of codebooks mismatch: to use n_q = {n_q}, bw should be {n_q * (1.5 / 2)}"
    n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
    codec_model = CompressionSolver.model_from_checkpoint(
        '//pretrained/facebook/encodec_24khz', device=device)
    codec_model.set_num_codebooks(n_q)
    codec_model = codec_model.to(device)
    path = 'facebook/multiband-diffusion'
    filename = f'mbd_comp_{n_q}.pt'
    models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
    DPs = []
    for i in range(len(models)):
        schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
        DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
    return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

Get the pretrained models for MultiBandDiffusion.

Args

bw : float
Bandwidth of the compression model.
device : torch.device or str, optional
Device on which the models are loaded.
n_q : int, optional
Number of quantizers to use within the compression model.
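
A usage sketch; when n_q is given it must agree with bw (1.5 kbps -> 2 codebooks, 3.0 -> 4, 6.0 -> 8):

import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

device = 'cuda' if torch.cuda.is_available() else 'cpu'
mbd = MultiBandDiffusion.get_mbd_24khz(bw=6.0, n_q=8, device=device)
print(mbd.sample_rate)  # 24000 for the EnCodec 24 kHz codec
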
def get_mbd_musicgen(device=None)
@staticmethod
def get_mbd_musicgen(device=None):
    """Load our diffusion models trained for MusicGen."""
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    path = 'facebook/multiband-diffusion'
    filename = 'mbd_musicgen_32khz.th'
    name = 'facebook/musicgen-small'
    codec_model = load_compression_model(name, device=device)
    models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
    DPs = []
    for i in range(len(models)):
        schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
        DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
    return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

Load our diffusion models trained for MusicGen.
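
A sketch pairing this with MusicGen-style tokens. The token layout used below (4 codebooks of size 2048 at 50 frames per second) is an assumption about the 32 kHz MusicGen codec, not read from the checkpoint:

import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

mbd = MultiBandDiffusion.get_mbd_musicgen()

# Placeholder tokens standing in for real MusicGen output (about one second).
tokens = torch.randint(0, 2048, (1, 4, 50), device=mbd.device)
wav = mbd.tokens_to_wav(tokens)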

Instance variables

prop sample_rate : int
@property
def sample_rate(self) -> int:
    return self.codec_model.sample_rate

Methods

def generate(self,
emb: torch.Tensor,
size: torch.Size | None = None,
step_list: List[int] | None = None)
def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
             step_list: tp.Optional[tp.List[int]] = None):
    """Generate waveform audio from the latent embeddings of the compression model.
    Args:
        emb (torch.Tensor): Conditioning embeddings
        size (None, torch.Size): Size of the output
            if None this is computed from the typical upsampling of the model.
        step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step.
    """
    if size is None:
        upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
        size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
    assert size[0] == emb.size(0)
    out = torch.zeros(size).to(self.device)
    for DP in self.DPs:
        out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
    return out

Generate waveform audio from the latent embeddings of the compression model.

Args

emb : torch.Tensor
Conditioning embeddings.
size : torch.Size, optional
Size of the output; if None, it is computed from the model's typical upsampling ratio.
step_list : list of int, optional
Steps of the Markov chain; defaults to 50 linearly spaced steps.
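
A sketch of both size modes (placeholder audio; when size is omitted, the output length is the number of embedding frames times the codec's sample_rate / frame_rate ratio):

import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)
wav = torch.randn(1, 1, mbd.sample_rate).to(mbd.device)
emb = mbd.get_condition(wav, sample_rate=mbd.sample_rate)

out = mbd.generate(emb)                   # size inferred from the upsampling ratio
out = mbd.generate(emb, size=wav.size())  # or pinned to the input's shape
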
def get_condition(self, wav: torch.Tensor, sample_rate: int) ‑> torch.Tensor
@torch.no_grad()
def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """Get the conditioning (i.e. latent representations of the compression model) from a waveform.
    Args:
        wav (torch.Tensor): The audio that we want to extract the conditioning from.
        sample_rate (int): Sample rate of the audio."""
    if sample_rate != self.sample_rate:
        wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
    codes, scale = self.codec_model.encode(wav)
    assert scale is None, "Scaled compression models not supported."
    emb = self.get_emb(codes)
    return emb

Get the conditioning (i.e. latent representations of the compression model) from a waveform.

Args

wav : torch.Tensor
The audio that we want to extract the conditioning from.
sample_rate : int
Sample rate of the audio.
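
A sketch showing that inputs at other sample rates are resampled internally before encoding (44.1 kHz placeholder here):

import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)

# One second at 44.1 kHz; resampled to the 24 kHz codec rate internally.
wav_44k = torch.randn(1, 1, 44100).to(mbd.device)
emb = mbd.get_condition(wav_44k, sample_rate=44100)
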
def get_emb(self, codes: torch.Tensor)
@torch.no_grad()
def get_emb(self, codes: torch.Tensor):
    """Get latent representation from the discrete codes.
    Args:
        codes (torch.Tensor): Discrete tokens."""
    emb = self.codec_model.decode_latent(codes)
    return emb

Get latent representation from the discrete codes.

Args

codes : torch.Tensor
Discrete tokens.
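
A sketch; the tokens come straight from the codec so the shapes line up:

import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)
wav = torch.randn(1, 1, mbd.sample_rate).to(mbd.device)

codes, scale = mbd.codec_model.encode(wav)  # discrete tokens [B, n_q, frames]
emb = mbd.get_emb(codes)                    # continuous latents for conditioning
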
def re_eq(self,
wav: torch.Tensor,
ref: torch.Tensor,
n_bands: int = 32,
strictness: float = 1)
def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
    """Match the eq to the encodec output by matching the standard deviation of some frequency bands.
    Args:
        wav (torch.Tensor): Audio to equalize.
        ref (torch.Tensor): Reference audio from which we match the spectrogram.
        n_bands (int): Number of bands of the eq.
        strictness (float): How strict the matching. 0 is no matching, 1 is exact matching.
    """
    split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
    bands = split(wav)
    bands_ref = split(ref)
    out = torch.zeros_like(ref)
    for i in range(n_bands):
        out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
    return out

Match the EQ to the EnCodec output by matching the standard deviation of some frequency bands.

Args

wav : torch.Tensor
Audio to equalize.
ref : torch.Tensor
Reference audio from which we match the spectrogram.
n_bands : int
Number of bands of the EQ.
strictness : float
How strictly to match; 0 is no matching, 1 is exact matching.
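
A sketch of the band matching on placeholder tensors; each of the n_bands frequency bands of wav is rescaled by (ref_band_std / wav_band_std) ** strictness:

import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)

# Placeholder pair: diffusion output and the codec's decoding of the same tokens.
wav = torch.randn(1, 1, mbd.sample_rate, device=mbd.device)
ref = torch.randn(1, 1, mbd.sample_rate, device=mbd.device)
matched = mbd.re_eq(wav=wav, ref=ref, n_bands=32, strictness=1.0)
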
def regenerate(self, wav: torch.Tensor, sample_rate: int)
def regenerate(self, wav: torch.Tensor, sample_rate: int):
    """Regenerate a waveform through compression and diffusion regeneration.
    Args:
        wav (torch.Tensor): Original 'ground truth' audio.
        sample_rate (int): Sample rate of the input (and output) wav.
    """
    if sample_rate != self.codec_model.sample_rate:
        wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
    emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
    size = wav.size()
    out = self.generate(emb, size=size)
    if sample_rate != self.codec_model.sample_rate:
        out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
    return out

Regenerate a waveform through compression and diffusion regeneration.

Args

wav : torch.Tensor
Original 'ground truth' audio.
sample_rate : int
Sample rate of the input (and output) wav.
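
An end-to-end sketch with torchaudio; 'input.wav' is a hypothetical file, and the downmix reflects the assumption that the 24 kHz codec expects mono [B, 1, T] input:

import torchaudio
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)

wav, sr = torchaudio.load('input.wav')     # [channels, T]
wav = wav.mean(dim=0, keepdim=True)[None]  # downmix to mono, add batch dim
out = mbd.regenerate(wav.to(mbd.device), sample_rate=sr)
torchaudio.save('regenerated.wav', out.squeeze(0).cpu(), sr)
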
def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32)
def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
    """Generate Waveform audio with diffusion from the discrete codes.
    Args:
        tokens (torch.Tensor): Discrete codes.
        n_bands (int): Bands for the eq matching.
    """
    wav_encodec = self.codec_model.decode(tokens)
    condition = self.get_emb(tokens)
    wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
    return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)

Generate waveform audio with diffusion from the discrete codes.

Args

tokens : torch.Tensor
Discrete codes.
n_bands : int
Number of bands for the EQ matching.
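
A sketch that sources the tokens from the codec itself (placeholder audio):

import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)
wav = torch.randn(1, 1, mbd.sample_rate).to(mbd.device)

tokens, _ = mbd.codec_model.encode(wav)
wav_diffusion = mbd.tokens_to_wav(tokens)  # EQ-matched to the codec's own decoding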