Module audiocraft.models.magnet

Main model for using MAGNeT. This will combine all the required components and provide easy access to the generation API.

Classes

class MAGNeT (**kwargs)
Expand source code
class MAGNeT(BaseGenModel):
    """MAGNeT main model with convenient generation API.
    Args:
       See MusicGen class.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # MAGNeT operates over a fixed sequence length defined in its config.
        self.duration = self.lm.cfg.dataset.segment_duration
        self.set_generation_params()

    @staticmethod
    def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None):
        """Return pretrained model, we provide six models:
        - facebook/magnet-small-10secs (300M), text to music, 10-second audio samples.
          # see: https://huggingface.co/facebook/magnet-small-10secs
        - facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples.
          # see: https://huggingface.co/facebook/magnet-medium-10secs
        - facebook/magnet-small-30secs (300M), text to music, 30-second audio samples.
          # see: https://huggingface.co/facebook/magnet-small-30secs
        - facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples.
          # see: https://huggingface.co/facebook/magnet-medium-30secs
        - facebook/audio-magnet-small (300M), text to sound-effect (10-second samples).
          # see: https://huggingface.co/facebook/audio-magnet-small
        - facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples).
          # see: https://huggingface.co/facebook/audio-magnet-medium

        Args:
            name (str): Name of the pretrained model to load. It is forwarded to both
                the compression model loader and the MAGNeT LM loader.
            device (optional): Device passed to both loaders. When None, 'cuda' is used
                if `torch.cuda.device_count()` is non-zero, otherwise 'cpu'.
        Returns:
            MAGNeT: A model built from the loaded compression model and language model.
        """
        if device is None:
            device = 'cuda' if torch.cuda.device_count() else 'cpu'

        compression_model = load_compression_model(name, device=device)
        # The LM loader is given the codec frame rate as an integer.
        lm = load_lm_model_magnet(name, compression_model_frame_rate=int(compression_model.frame_rate), device=device)

        # Only touch the 'self_wav' conditioner when the loaded LM actually has one.
        if 'self_wav' in lm.condition_provider.conditioners:
            lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True

        kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm}
        return MAGNeT(**kwargs)

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 0,
                              top_p: float = 0.9, temperature: float = 3.0,
                              max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0,
                              decoding_steps: tp.Optional[tp.List[int]] = None,
                              span_arrangement: str = 'nonoverlap'):
        """Set the generation parameters for MAGNeT.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 0.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
            temperature (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
            max_cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 10.0.
            min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing. Defaults to 1.0.
            decoding_steps (list of n_q ints, optional): The number of iterative decoding steps,
                                                         for each of the n_q RVQ codebooks.
                                                         Defaults to [20, 10, 10, 10] when None.
            span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap')
                                              or overlapping spans ('stride1') in the masking scheme.
        """
        # A None sentinel replaces the former mutable list default, so the default
        # value is rebuilt on every call instead of living on the function object.
        if decoding_steps is None:
            decoding_steps = [20, 10, 10, 10]
        self.generation_params = {
            'use_sampling': use_sampling,
            'temp': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'max_cfg_coef': max_cfg_coef,
            'min_cfg_coef': min_cfg_coef,
            # Coerce each entry to int and store a fresh list, never the caller's object.
            'decoding_steps': [int(s) for s in decoding_steps],
            'span_arrangement': span_arrangement
        }

MAGNeT main model with convenient generation API.

Args

See MusicGen class.

Ancestors

Static methods

def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None)
Expand source code
@staticmethod
def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None):
    """Load one of the six pretrained MAGNeT checkpoints and wrap it in a MAGNeT model.

    Available checkpoints:
    - facebook/magnet-small-10secs (300M), text to music, 10-second audio samples.
      # see: https://huggingface.co/facebook/magnet-small-10secs
    - facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples.
      # see: https://huggingface.co/facebook/magnet-medium-10secs
    - facebook/magnet-small-30secs (300M), text to music, 30-second audio samples.
      # see: https://huggingface.co/facebook/magnet-small-30secs
    - facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples.
      # see: https://huggingface.co/facebook/magnet-medium-30secs
    - facebook/audio-magnet-small (300M), text to sound-effect (10-second samples).
      # see: https://huggingface.co/facebook/audio-magnet-small
    - facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples).
      # see: https://huggingface.co/facebook/audio-magnet-medium

    Args:
        name (str): Checkpoint name, forwarded to both model loaders.
        device (optional): Device forwarded to both loaders. When None, 'cuda' is
            picked if `torch.cuda.device_count()` is non-zero, otherwise 'cpu'.
    Returns:
        MAGNeT: Model wrapping the loaded compression model and language model.
    """
    if device is None:
        device = 'cuda' if torch.cuda.device_count() else 'cpu'

    codec = load_compression_model(name, device=device)
    # The LM loader receives the codec frame rate as an integer.
    codec_frame_rate = int(codec.frame_rate)
    language_model = load_lm_model_magnet(name, compression_model_frame_rate=codec_frame_rate, device=device)

    # Flip match_len_on_eval only if the LM exposes a 'self_wav' conditioner.
    conditioners = language_model.condition_provider.conditioners
    if 'self_wav' in conditioners:
        conditioners['self_wav'].match_len_on_eval = True

    return MAGNeT(name=name, compression_model=codec, lm=language_model)

Return pretrained model, we provide six models:

- facebook/magnet-small-10secs (300M), text to music, 10-second audio samples. See: https://huggingface.co/facebook/magnet-small-10secs
- facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples. See: https://huggingface.co/facebook/magnet-medium-10secs
- facebook/magnet-small-30secs (300M), text to music, 30-second audio samples. See: https://huggingface.co/facebook/magnet-small-30secs
- facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples. See: https://huggingface.co/facebook/magnet-medium-30secs
- facebook/audio-magnet-small (300M), text to sound-effect (10-second samples). See: https://huggingface.co/facebook/audio-magnet-small
- facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples). See: https://huggingface.co/facebook/audio-magnet-medium

Methods

def set_generation_params(self,
use_sampling: bool = True,
top_k: int = 0,
top_p: float = 0.9,
temperature: float = 3.0,
max_cfg_coef: float = 10.0,
min_cfg_coef: float = 1.0,
decoding_steps: List[int] = [20, 10, 10, 10],
span_arrangement: str = 'nonoverlap')
Expand source code
def set_generation_params(self, use_sampling: bool = True, top_k: int = 0,
                          top_p: float = 0.9, temperature: float = 3.0,
                          max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0,
                          decoding_steps: tp.Optional[tp.List[int]] = None,
                          span_arrangement: str = 'nonoverlap'):
    """Set the generation parameters for MAGNeT.

    The values are stored as a dict on `self.generation_params`.

    Args:
        use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
        top_k (int, optional): top_k used for sampling. Defaults to 0.
        top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
        temperature (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
        max_cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 10.0.
        min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing. Defaults to 1.0.
        decoding_steps (list of n_q ints, optional): The number of iterative decoding steps,
                                                     for each of the n_q RVQ codebooks.
                                                     Defaults to [20, 10, 10, 10] when None.
        span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap')
                                          or overlapping spans ('stride1') in the masking scheme.
    """
    # A None sentinel replaces the former mutable list default, so the default
    # value is rebuilt on every call instead of living on the function object.
    if decoding_steps is None:
        decoding_steps = [20, 10, 10, 10]
    self.generation_params = {
        'use_sampling': use_sampling,
        'temp': temperature,
        'top_k': top_k,
        'top_p': top_p,
        'max_cfg_coef': max_cfg_coef,
        'min_cfg_coef': min_cfg_coef,
        # Coerce each entry to int and store a fresh list, never the caller's object.
        'decoding_steps': [int(s) for s in decoding_steps],
        'span_arrangement': span_arrangement
    }

Set the generation parameters for MAGNeT.

Args

use_sampling : bool, optional
Use sampling if True, else do argmax decoding. Defaults to True.
top_k : int, optional
top_k used for sampling. Defaults to 0.
top_p : float, optional
top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
temperature : float, optional
Initial softmax temperature parameter. Defaults to 3.0.
max_cfg_coef : float, optional
Coefficient used for classifier free guidance. Defaults to 10.0.
min_cfg_coef : float, optional
End coefficient of classifier free guidance annealing. Defaults to 1.0.
decoding_steps : list of n_q ints, optional
The number of iterative decoding steps, for each of the n_q RVQ codebooks.
span_arrangement : str, optional
Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1') in the masking scheme.

Inherited members