Module audiocraft.models.magnet
Main model for using MAGNeT. This will combine all the required components and provide easy access to the generation API.
Classes
class MAGNeT (**kwargs)
class MAGNeT(BaseGenModel):
    """MAGNeT main model with convenient generation API.

    Args:
        See MusicGen class.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # MAGNeT operates over a fixed sequence length defined in it's config.
        self.duration = self.lm.cfg.dataset.segment_duration
        self.set_generation_params()

    @staticmethod
    def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None):
        """Return pretrained model, we provide six models:
        - facebook/magnet-small-10secs (300M), text to music, 10-second audio samples.
          # see: https://huggingface.co/facebook/magnet-small-10secs
        - facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples.
          # see: https://huggingface.co/facebook/magnet-medium-10secs
        - facebook/magnet-small-30secs (300M), text to music, 30-second audio samples.
          # see: https://huggingface.co/facebook/magnet-small-30secs
        - facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples.
          # see: https://huggingface.co/facebook/magnet-medium-30secs
        - facebook/audio-magnet-small (300M), text to sound-effect (10-second samples).
          # see: https://huggingface.co/facebook/audio-magnet-small
        - facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples).
          # see: https://huggingface.co/facebook/audio-magnet-medium
        """
        if device is None:
            if torch.cuda.device_count():
                device = 'cuda'
            else:
                device = 'cpu'

        compression_model = load_compression_model(name, device=device)
        lm = load_lm_model_magnet(name, compression_model_frame_rate=int(compression_model.frame_rate),
                                  device=device)

        if 'self_wav' in lm.condition_provider.conditioners:
            lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True

        kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm}
        return MAGNeT(**kwargs)

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 0,
                              top_p: float = 0.9, temperature: float = 3.0,
                              max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0,
                              decoding_steps: tp.List[int] = [20, 10, 10, 10],
                              span_arrangement: str = 'nonoverlap'):
        """Set the generation parameters for MAGNeT.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding.
                Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 0.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used.
                Defaults to 0.9.
            temperature (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
            max_cfg_coef (float, optional): Coefficient used for classifier free guidance.
                Defaults to 10.0.
            min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing.
                Defaults to 1.0.
            decoding_steps (list of n_q ints, optional): The number of iterative decoding steps,
                for each of the n_q RVQ codebooks.
            span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap')
                or overlapping spans ('stride1') in the masking scheme.
        """
        self.generation_params = {
            'use_sampling': use_sampling,
            'temp': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'max_cfg_coef': max_cfg_coef,
            'min_cfg_coef': min_cfg_coef,
            'decoding_steps': [int(s) for s in decoding_steps],
            'span_arrangement': span_arrangement
        }
MAGNeT main model with convenient generation API.
Args
See MusicGen class.
Ancestors
- BaseGenModel
- abc.ABC
Static methods
def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None)
@staticmethod
def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None):
    """Return pretrained model, we provide six models:
    - facebook/magnet-small-10secs (300M), text to music, 10-second audio samples.
      # see: https://huggingface.co/facebook/magnet-small-10secs
    - facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples.
      # see: https://huggingface.co/facebook/magnet-medium-10secs
    - facebook/magnet-small-30secs (300M), text to music, 30-second audio samples.
      # see: https://huggingface.co/facebook/magnet-small-30secs
    - facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples.
      # see: https://huggingface.co/facebook/magnet-medium-30secs
    - facebook/audio-magnet-small (300M), text to sound-effect (10-second samples).
      # see: https://huggingface.co/facebook/audio-magnet-small
    - facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples).
      # see: https://huggingface.co/facebook/audio-magnet-medium
    """
    if device is None:
        if torch.cuda.device_count():
            device = 'cuda'
        else:
            device = 'cpu'

    compression_model = load_compression_model(name, device=device)
    lm = load_lm_model_magnet(name, compression_model_frame_rate=int(compression_model.frame_rate),
                              device=device)

    if 'self_wav' in lm.condition_provider.conditioners:
        lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True

    kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm}
    return MAGNeT(**kwargs)
Return pretrained model. We provide six models:
- facebook/magnet-small-10secs (300M), text to music, 10-second audio samples. See: https://huggingface.co/facebook/magnet-small-10secs
- facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples. See: https://huggingface.co/facebook/magnet-medium-10secs
- facebook/magnet-small-30secs (300M), text to music, 30-second audio samples. See: https://huggingface.co/facebook/magnet-small-30secs
- facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples. See: https://huggingface.co/facebook/magnet-medium-30secs
- facebook/audio-magnet-small (300M), text to sound effects, 10-second audio samples. See: https://huggingface.co/facebook/audio-magnet-small
- facebook/audio-magnet-medium (1.5B), text to sound effects, 10-second audio samples. See: https://huggingface.co/facebook/audio-magnet-medium
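A minimal usage sketch (not part of the upstream docstrings), assuming the inherited generate() method and sample_rate attribute from BaseGenModel and the audio_write helper from audiocraft.data.audio; the prompt text and output names are placeholders:

    from audiocraft.models import MAGNeT
    from audiocraft.data.audio import audio_write

    # Load a pretrained text-to-music checkpoint (weights are downloaded on first use).
    model = MAGNeT.get_pretrained('facebook/magnet-small-10secs')

    # Generate one waveform per text description using the inherited generation API.
    descriptions = ['80s electronic track with melodic synthesizers']
    wav = model.generate(descriptions)  # tensor of shape [batch, channels, samples]

    # Write each sample to disk at the model's native sample rate.
    for idx, one_wav in enumerate(wav):
        audio_write(f'magnet_{idx}', one_wav.cpu(), model.sample_rate, strategy='loudness')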
Methods
def set_generation_params(self,
use_sampling: bool = True,
top_k: int = 0,
top_p: float = 0.9,
temperature: float = 3.0,
max_cfg_coef: float = 10.0,
min_cfg_coef: float = 1.0,
decoding_steps: List[int] = [20, 10, 10, 10],
span_arrangement: str = 'nonoverlap')
def set_generation_params(self, use_sampling: bool = True, top_k: int = 0,
                          top_p: float = 0.9, temperature: float = 3.0,
                          max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0,
                          decoding_steps: tp.List[int] = [20, 10, 10, 10],
                          span_arrangement: str = 'nonoverlap'):
    """Set the generation parameters for MAGNeT.

    Args:
        use_sampling (bool, optional): Use sampling if True, else do argmax decoding.
            Defaults to True.
        top_k (int, optional): top_k used for sampling. Defaults to 0.
        top_p (float, optional): top_p used for sampling, when set to 0 top_k is used.
            Defaults to 0.9.
        temperature (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
        max_cfg_coef (float, optional): Coefficient used for classifier free guidance.
            Defaults to 10.0.
        min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing.
            Defaults to 1.0.
        decoding_steps (list of n_q ints, optional): The number of iterative decoding steps,
            for each of the n_q RVQ codebooks.
        span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap')
            or overlapping spans ('stride1') in the masking scheme.
    """
    self.generation_params = {
        'use_sampling': use_sampling,
        'temp': temperature,
        'top_k': top_k,
        'top_p': top_p,
        'max_cfg_coef': max_cfg_coef,
        'min_cfg_coef': min_cfg_coef,
        'decoding_steps': [int(s) for s in decoding_steps],
        'span_arrangement': span_arrangement
    }
Set the generation parameters for MAGNeT.
Args
use_sampling : bool, optional
    Use sampling if True, else do argmax decoding. Defaults to True.
top_k : int, optional
    top_k used for sampling. Defaults to 0.
top_p : float, optional
    top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
temperature : float, optional
    Initial softmax temperature parameter. Defaults to 3.0.
max_cfg_coef : float, optional
    Coefficient used for classifier free guidance. Defaults to 10.0.
min_cfg_coef : float, optional
    End coefficient of classifier free guidance annealing. Defaults to 1.0.
decoding_steps : list of n_q ints, optional
    The number of iterative decoding steps, for each of the n_q RVQ codebooks.
span_arrangement : str, optional
    Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1') in the masking scheme.
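As an illustration (a hedged sketch, reusing get_pretrained from above): decoding_steps carries one entry per RVQ codebook, four in the default [20, 10, 10, 10]; the alternative values below are placeholders, not recommended settings:

    from audiocraft.models import MAGNeT

    model = MAGNeT.get_pretrained('facebook/magnet-small-10secs')

    # One entry per RVQ codebook: spend more iterations on the first (coarsest)
    # codebook and fewer on the residual ones. Values here are illustrative.
    model.set_generation_params(
        use_sampling=True,
        top_p=0.9,
        temperature=3.0,
        max_cfg_coef=10.0,
        min_cfg_coef=1.0,
        decoding_steps=[60, 10, 10, 10],
        span_arrangement='stride1',  # overlapping spans in the masking scheme
    )

    wav = model.generate(['ambient rain on a tin roof'])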
Inherited members