Module audiocraft.models.builders

All the functions to build the relevant models and modules from the Hydra config.

Functions

def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.dictconfig.DictConfig) ‑> CodebooksPatternProvider
Expand source code
def get_codebooks_pattern_provider(
    n_q: int, cfg: omegaconf.DictConfig
) -> CodebooksPatternProvider:
    """Instantiate a codebooks pattern provider object.

    The provider class is selected by ``cfg.modeling``; its keyword arguments
    come from the config section of the same name, when present.
    """
    providers = {
        "parallel": ParallelPatternProvider,
        "delay": DelayedPatternProvider,
        "unroll": UnrolledPatternProvider,
        "coarse_first": CoarseFirstPattern,
        "musiclm": MusicLMPattern,
    }
    provider_name = cfg.modeling
    provider_kwargs = (
        dict_from_config(cfg.get(provider_name)) if hasattr(cfg, provider_name) else {}
    )
    provider_cls = providers[provider_name]
    return provider_cls(n_q, **provider_kwargs)

Instantiate a codebooks pattern provider object.

def get_compression_model(cfg: omegaconf.dictconfig.DictConfig) ‑> CompressionModel
Expand source code
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
    """Instantiate a compression model.

    Only the ``encodec`` compression model is supported; raises KeyError otherwise.
    """
    if cfg.compression_model != "encodec":
        raise KeyError(f"Unexpected compression model {cfg.compression_model}")
    encodec_kwargs = dict_from_config(getattr(cfg, "encodec"))
    autoencoder_name = encodec_kwargs.pop("autoencoder")
    quantizer_name = encodec_kwargs.pop("quantizer")
    encoder, decoder = get_encodec_autoencoder(autoencoder_name, cfg)
    quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
    # frame rate follows from the encoder's total downsampling factor
    frame_rate = encodec_kwargs["sample_rate"] // encoder.hop_length
    renormalize = encodec_kwargs.pop("renormalize", False)
    encodec_kwargs.pop("renorm", None)  # deprecated param, dropped silently
    model = EncodecModel(
        encoder,
        decoder,
        quantizer,
        frame_rate=frame_rate,
        renormalize=renormalize,
        **encodec_kwargs,
    )
    return model.to(cfg.device)

Instantiate a compression model.

def get_condition_fuser(cfg: omegaconf.dictconfig.DictConfig) ‑> ConditionFuser
Expand source code
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
    """Instantiate a condition fuser object.

    Keys of ``cfg.fuser`` that name a fuse method become the ``fuse2cond``
    mapping; all remaining keys are forwarded as ConditionFuser kwargs.
    """
    fuser_cfg = getattr(cfg, "fuser")
    methods = ["sum", "cross", "prepend", "ignore", "input_interpolate"]
    fuse2cond = {}
    for method in methods:
        if method in fuser_cfg:
            fuse2cond[method] = fuser_cfg[method]
    extra_kwargs = {key: val for key, val in fuser_cfg.items() if key not in methods}
    return ConditionFuser(fuse2cond=fuse2cond, **extra_kwargs)

Instantiate a condition fuser object.

def get_conditioner_provider(output_dim: int, cfg: omegaconf.dictconfig.DictConfig) ‑> ConditioningProvider
Expand source code
def get_conditioner_provider(
    output_dim: int, cfg: omegaconf.DictConfig
) -> ConditioningProvider:
    """Instantiate a conditioning model.

    Builds one conditioner per entry of ``cfg.conditioners`` (keyed by its
    ``model`` field) and wraps them all in a ConditioningProvider.
    """
    device = cfg.device
    duration = cfg.dataset.segment_duration
    conditioners_cfg = getattr(cfg, "conditioners")
    dict_cfg = {} if conditioners_cfg is None else dict_from_config(conditioners_cfg)
    provider_args = dict_cfg.pop("args", {})
    # these two options are stripped from the provider args (popped and discarded here)
    provider_args.pop("merge_text_conditions_p", None)
    provider_args.pop("drop_desc_p", None)

    conditioners: tp.Dict[str, BaseConditioner] = {}
    for cond_name, one_cfg in dict_cfg.items():
        model_type = one_cfg["model"]
        model_args = one_cfg[model_type]
        key = str(cond_name)
        if model_type == "t5":
            conditioners[key] = T5Conditioner(
                output_dim=output_dim, device=device, **model_args
            )
        elif model_type == "lut":
            # LUT conditioner runs without an explicit device argument
            conditioners[key] = LUTConditioner(output_dim=output_dim, **model_args)
        elif model_type == "chroma_stem":
            conditioners[key] = ChromaStemConditioner(
                output_dim=output_dim, duration=duration, device=device, **model_args
            )
        elif model_type in {"chords_emb", "drum_latents", "melody"}:
            # these three share the same constructor shape: device + model args only
            klass = {
                "chords_emb": ChordsEmbConditioner,
                "drum_latents": DrumsConditioner,
                "melody": MelodyConditioner,
            }[model_type]
            conditioners[key] = klass(device=device, **model_args)
        elif model_type == "clap":
            conditioners[key] = CLAPEmbeddingConditioner(
                output_dim=output_dim, device=device, **model_args
            )
        elif model_type == 'style':
            conditioners[key] = StyleConditioner(
                output_dim=output_dim, device=device, **model_args
            )
        else:
            raise ValueError(f"Unrecognized conditioning model: {model_type}")
    return ConditioningProvider(conditioners, device=device, **provider_args)

Instantiate a conditioning model.

def get_debug_compression_model(device='cpu', sample_rate: int = 32000)
Expand source code
def get_debug_compression_model(device="cpu", sample_rate: int = 32000):
    """Instantiate a debug compression model to be used for unit tests."""
    ratios_by_rate: tp.Dict[int, tp.List[int]] = {
        16000: [10, 8, 8],  # 25 Hz at 16kHz
        32000: [10, 8, 16],  # 25 Hz at 32kHz
    }
    assert (
        sample_rate in ratios_by_rate
    ), "unsupported sample rate for debug compression model"
    seanet_kwargs: dict = {
        "n_filters": 4,
        "n_residual_layers": 1,
        "dimension": 32,
        "ratios": ratios_by_rate[sample_rate],
    }
    encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
    decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
    # run one random batch through the quantizer to initialize kmeans etc.
    quantizer(torch.randn(8, 32, 128), 1)
    model = EncodecModel(
        encoder,
        decoder,
        quantizer,
        frame_rate=25,
        sample_rate=sample_rate,
        channels=1,
    ).to(device)
    return model.eval()

Instantiate a debug compression model to be used for unit tests.

def get_debug_lm_model(device='cpu')
Expand source code
def get_debug_lm_model(device="cpu"):
    """Instantiate a debug LM to be used for unit tests."""
    dim = 16
    pattern_provider = DelayedPatternProvider(n_q=4)
    # single whitespace-tokenized LUT conditioner on "description"
    conditioners = {
        "description": LUTConditioner(
            n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"
        ),
    }
    fuser = ConditionFuser(
        {"cross": ["description"], "prepend": [], "sum": [], "input_interpolate": []}
    )
    model = LMModel(
        pattern_provider,
        ConditioningProvider(conditioners),
        fuser,
        n_q=4,
        card=400,
        dim=dim,
        num_heads=4,
        custom=True,
        num_layers=2,
        cross_attention=True,
        causal=True,
    )
    return model.to(device).eval()

Instantiate a debug LM to be used for unit tests.

def get_diffusion_model(cfg: omegaconf.dictconfig.DictConfig)
Expand source code
def get_diffusion_model(cfg: omegaconf.DictConfig):
    """Instantiate a diffusion U-Net from the Hydra config."""
    # TODO Find a way to infer the channels from dset
    return DiffusionUnet(
        chin=cfg.channels,
        num_steps=cfg.schedule.num_steps,
        **cfg.diffusion_unet,
    )
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.dictconfig.DictConfig)
Expand source code
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
    """Build the EnCodec autoencoder pair from the Hydra config.

    Args:
        encoder_name: autoencoder architecture name; only "seanet" is supported.
        cfg: full config; the `seanet` section holds shared kwargs plus
            `encoder`/`decoder` sub-sections of per-module overrides.
    Returns:
        Tuple of (SEANetEncoder, SEANetDecoder).
    Raises:
        KeyError: if `encoder_name` is not a supported architecture.
    """
    if encoder_name == "seanet":
        kwargs = dict_from_config(getattr(cfg, "seanet"))
        # per-module sub-sections override the shared seanet kwargs
        encoder_override_kwargs = kwargs.pop("encoder")
        decoder_override_kwargs = kwargs.pop("decoder")
        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
        encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
        decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
        return encoder, decoder
    else:
        # Fix: report the value actually rejected. The previous message read
        # `cfg.compression_model`, which is unrelated to this function's input
        # (and could itself fail if that key was absent from cfg).
        raise KeyError(f"Unexpected autoencoder {encoder_name}")
def get_jasco_model(cfg: omegaconf.dictconfig.DictConfig,
compression_model: CompressionModel | None = None) ‑> FlowMatchingModel
Expand source code
def get_jasco_model(cfg: omegaconf.DictConfig,
                    compression_model: tp.Optional[CompressionModel] = None) -> FlowMatchingModel:
    """Instantiate a JASCO flow-matching model from the Hydra config.

    Args:
        cfg: full config; reads the `transformer_lm`, `attribute_dropout`,
            `classifier_free_guidance`, `fuser` and `conditioners` sections.
        compression_model: required when a drums (`self_wav`) conditioner is
            configured; it is attached to that conditioner and frozen.
    Returns:
        A FlowMatchingModel moved to `cfg.device`.
    """
    kwargs = dict_from_config(getattr(cfg, "transformer_lm"))
    attribute_dropout = dict_from_config(getattr(cfg, "attribute_dropout"))
    cls_free_guidance = dict_from_config(getattr(cfg, "classifier_free_guidance"))
    cfg_prob = cls_free_guidance["training_dropout"]
    cfg_coef = cls_free_guidance["inference_coef"]
    fuser = get_condition_fuser(cfg)
    condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
    if JascoCondConst.DRM.value in condition_provider.conditioners:  # use self_wav for drums
        assert compression_model is not None

        # use compression model for drums conditioning; frozen (no gradients)
        condition_provider.conditioners.self_wav.compression_model = compression_model
        condition_provider.conditioners.self_wav.compression_model.requires_grad_(False)

    # downcast to jasco conditioning provider
    seq_len = cfg.compression_model_framerate * cfg.dataset.segment_duration
    # chords cardinality only applies when a chords conditioner is configured
    chords_card = cfg.conditioners.chords.chords_emb.card if JascoCondConst.CRD.value in cfg.conditioners else -1
    condition_provider = JascoConditioningProvider(device=condition_provider.device,
                                                   conditioners=condition_provider.conditioners,
                                                   chords_card=chords_card,
                                                   sequence_length=seq_len)

    if len(fuser.fuse2cond["cross"]) > 0:  # enforce cross-att programmatically
        kwargs["cross_attention"] = True

    # drop LM-only transformer kwargs before building the flow-matching model
    kwargs.pop("n_q", None)
    kwargs.pop("card", None)

    return FlowMatchingModel(
        condition_provider=condition_provider,
        fuser=fuser,
        cfg_dropout=cfg_prob,
        cfg_coef=cfg_coef,
        attribute_dropout=attribute_dropout,
        dtype=getattr(torch, cfg.dtype),
        device=cfg.device,
        **kwargs,
    ).to(cfg.device)
def get_lm_model(cfg: omegaconf.dictconfig.DictConfig) ‑> LMModel
Expand source code
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
    """Instantiate a transformer LM."""
    if cfg.lm_model not in ["transformer_lm", "transformer_lm_magnet"]:
        raise KeyError(f"Unexpected LM model {cfg.lm_model}")
    lm_kwargs = dict_from_config(getattr(cfg, "transformer_lm"))
    n_q = lm_kwargs["n_q"]
    q_modeling = lm_kwargs.pop("q_modeling", None)
    pattern_cfg = getattr(cfg, "codebooks_pattern")
    attribute_dropout = dict_from_config(getattr(cfg, "attribute_dropout"))
    cls_free_guidance = dict_from_config(getattr(cfg, "classifier_free_guidance"))
    cfg_prob = cls_free_guidance["training_dropout"]
    cfg_coef = cls_free_guidance["inference_coef"]
    fuser = get_condition_fuser(cfg)
    condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg).to(cfg.device)
    if len(fuser.fuse2cond["cross"]) > 0:  # enforce cross-att programmatically
        lm_kwargs["cross_attention"] = True
    if pattern_cfg.modeling is None:
        # fall back to a delay pattern derived from transformer_lm.q_modeling
        assert (
            q_modeling is not None
        ), "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
        pattern_cfg = omegaconf.OmegaConf.create(
            {"modeling": q_modeling, "delay": {"delays": list(range(n_q))}}
        )

    pattern_provider = get_codebooks_pattern_provider(n_q, pattern_cfg)
    lm_class = MagnetLMModel if cfg.lm_model == "transformer_lm_magnet" else LMModel
    return lm_class(
        pattern_provider=pattern_provider,
        condition_provider=condition_provider,
        fuser=fuser,
        cfg_dropout=cfg_prob,
        cfg_coef=cfg_coef,
        attribute_dropout=attribute_dropout,
        dtype=getattr(torch, cfg.dtype),
        device=cfg.device,
        **lm_kwargs,
    ).to(cfg.device)

Instantiate a transformer LM.

def get_processor(cfg, sample_rate: int = 24000)
Expand source code
def get_processor(cfg, sample_rate: int = 24000):
    """Instantiate a sample processor; falls back to the identity SampleProcessor
    unless the config enables a known processor."""
    if not cfg.use:
        return SampleProcessor()
    processor_kwargs = dict(cfg)
    processor_kwargs.pop("use")
    processor_kwargs.pop("name")
    if cfg.name == "multi_band_processor":
        return MultiBandProcessor(sample_rate=sample_rate, **processor_kwargs)
    return SampleProcessor()
def get_quantizer(quantizer: str, cfg: omegaconf.dictconfig.DictConfig, dimension: int) ‑> BaseQuantizer
Expand source code
def get_quantizer(
    quantizer: str, cfg: omegaconf.DictConfig, dimension: int
) -> qt.BaseQuantizer:
    """Instantiate a quantizer from the config section named after `quantizer`."""
    quantizer_classes = {
        "no_quant": qt.DummyQuantizer,
        "rvq": qt.ResidualVectorQuantizer,
    }
    quantizer_cls = quantizer_classes[quantizer]
    quantizer_kwargs = dict_from_config(getattr(cfg, quantizer))
    if quantizer != "no_quant":
        # real quantizers need the latent dimension from the encoder
        quantizer_kwargs["dimension"] = dimension
    return quantizer_cls(**quantizer_kwargs)
def get_watermark_model(cfg: omegaconf.dictconfig.DictConfig) ‑> WMModel
Expand source code
def get_watermark_model(cfg: omegaconf.DictConfig) -> WMModel:
    """Build a WMModel based on AudioSeal. This requires audioseal to be installed.

    Args:
        cfg: config with a required `seanet` section; optional `audioseal`,
            `detector`, `device` and `dtype` entries.
    Returns:
        An AudioSeal WMModel (generator + detector) on `cfg.device`/`cfg.dtype`.
    """
    import audioseal

    from .watermark import AudioSeal

    # Build encoder and decoder directly using the audiocraft API to avoid a cyclic import
    assert hasattr(
        cfg, "seanet"
    ), "Missing required `seanet` parameters in AudioSeal config"
    encoder, decoder = get_encodec_autoencoder("seanet", cfg)

    # Build the message processor; nbits defaults to 0 when no `audioseal` section is given
    kwargs = (
        dict_from_config(getattr(cfg, "audioseal")) if hasattr(cfg, "audioseal") else {}
    )
    nbits = kwargs.get("nbits", 0)
    hidden_size = getattr(cfg.seanet, "dimension", 128)
    msg_processor = audioseal.MsgProcessor(nbits, hidden_size=hidden_size)

    # Build detector using the audioseal API
    def _get_audioseal_detector():
        # We don't need encoder and decoder params from seanet, remove them
        seanet_cfg = dict_from_config(cfg.seanet)
        seanet_cfg.pop("encoder")
        seanet_cfg.pop("decoder")
        detector_cfg = dict_from_config(cfg.detector)

        typed_seanet_cfg = audioseal.builder.SEANetConfig(**seanet_cfg)
        typed_detector_cfg = audioseal.builder.DetectorConfig(**detector_cfg)
        _cfg = audioseal.builder.AudioSealDetectorConfig(
            nbits=nbits, seanet=typed_seanet_cfg, detector=typed_detector_cfg
        )
        return audioseal.builder.create_detector(_cfg)

    detector = _get_audioseal_detector()
    generator = audioseal.AudioSealWM(
        encoder=encoder, decoder=decoder, msg_processor=msg_processor
    )
    model = AudioSeal(generator=generator, detector=detector, nbits=nbits)

    # device/dtype default to cpu/float32 when absent from the config
    device = torch.device(getattr(cfg, "device", "cpu"))
    dtype = getattr(torch, getattr(cfg, "dtype", "float32"))
    return model.to(device=device, dtype=dtype)

Build a WMModel based on AudioSeal. This requires audioseal to be installed

def get_wrapped_compression_model(compression_model: CompressionModel,
cfg: omegaconf.dictconfig.DictConfig) ‑> CompressionModel
Expand source code
def get_wrapped_compression_model(
    compression_model: CompressionModel, cfg: omegaconf.DictConfig
) -> CompressionModel:
    """Optionally wrap the compression model per the config.

    Applies stereo codebook interleaving when enabled, then restricts the
    number of codebooks when `compression_model_n_q` is set.
    """
    if hasattr(cfg, "interleave_stereo_codebooks") and cfg.interleave_stereo_codebooks.use:
        wrap_kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
        wrap_kwargs.pop("use")
        compression_model = InterleaveStereoCompressionModel(
            compression_model, **wrap_kwargs
        )
    if hasattr(cfg, "compression_model_n_q") and cfg.compression_model_n_q is not None:
        compression_model.set_num_codebooks(cfg.compression_model_n_q)
    return compression_model