Module audiocraft.solvers.musicgen

Classes

class MusicGenSolver (cfg: omegaconf.dictconfig.DictConfig)
class MusicGenSolver(base.StandardSolver):
    """Solver for MusicGen training task.

    Used in: https://arxiv.org/abs/2306.05284
    """
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC

    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        # easier access to sampling parameters
        self.generation_params = {
            'use_sampling': self.cfg.generate.lm.use_sampling,
            'temp': self.cfg.generate.lm.temp,
            'top_k': self.cfg.generate.lm.top_k,
            'top_p': self.cfg.generate.lm.top_p,
        }
        self._best_metric_name: tp.Optional[str] = 'ce'

        self._cached_batch_writer = None
        self._cached_batch_loader = None
        if cfg.cache.path:
            if cfg.cache.write:
                self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path))
                if self.cfg.cache.write_num_shards:
                    self.logger.warning("Multiple shard cache, best_metric_name will be set to None.")
                    self._best_metric_name = None
            else:
                self._cached_batch_loader = CachedBatchLoader(
                    Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers,
                    min_length=self.cfg.optim.updates_per_epoch or 1)
                self.dataloaders['original_train'] = self.dataloaders['train']
                self.dataloaders['train'] = self._cached_batch_loader  # type: ignore

    @staticmethod
    def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                                 device: tp.Optional[str] = None, autocast: bool = True,
                                 batch_size: tp.Optional[int] = None,
                                 override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                                 **kwargs):
        """Mostly a convenience function around magma.train.get_solver_from_sig,
        populating all the proper param, deactivating EMA, FSDP, loading the best state,
        basically all you need to get a solver ready to "play" with in single GPU mode
        and with minimal memory overhead.

        Args:
            sig (str): signature to load.
            dtype (str or None): potential dtype, as a string, i.e. 'float16'.
            device (str or None): potential device, as a string, i.e. 'cuda'.
            override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'.
        """
        from audiocraft import train
        our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
        our_override_cfg['autocast'] = autocast
        if dtype is not None:
            our_override_cfg['dtype'] = dtype
        if device is not None:
            our_override_cfg['device'] = device
        if batch_size is not None:
            our_override_cfg['dataset'] = {'batch_size': batch_size}
        if override_cfg is None:
            override_cfg = {}
        override_cfg = omegaconf.OmegaConf.merge(
            omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
        solver = train.get_solver_from_sig(
            sig, override_cfg=override_cfg,
            load_best=True, disable_fsdp=True,
            ignore_state_keys=['optimizer', 'ema'], **kwargs)
        solver.model.eval()
        return solver

    def get_formatter(self, stage_name: str) -> flashy.Formatter:
        return flashy.Formatter({
            'lr': '.2E',
            'ce': '.3f',
            'ppl': '.3f',
            'grad_norm': '.3E',
        }, exclude_keys=['ce_q*', 'ppl_q*'])

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        return self._best_metric_name

    def initialize_optimization(self) -> None:
        if self.cfg.fsdp.use:
            assert not self.cfg.autocast, "Cannot use autocast with fsdp"
            self.model = self.wrap_with_fsdp(self.model)
        self.register_ema('model')
        # initialize optimization
        self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
        self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
        self.register_stateful('model', 'optimizer', 'lr_scheduler')
        self.register_best_state('model')
        self.autocast_dtype = {
            'float16': torch.float16, 'bfloat16': torch.bfloat16
        }[self.cfg.autocast_dtype]
        self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
        if self.cfg.fsdp.use:
            need_scaler = self.cfg.fsdp.param_dtype == 'float16'
        else:
            need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
        if need_scaler:
            if self.cfg.fsdp.use:
                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
                self.scaler = ShardedGradScaler()  # type: ignore
            else:
                self.scaler = torch.cuda.amp.GradScaler()
            self.register_stateful('scaler')

    def build_model(self) -> None:
        """Instantiate models and optimizer."""
        # we can potentially not use all quantizers with which the EnCodec model was trained
        # (e.g. we trained the model with quantizers dropout)
        self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
            self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
        assert self.compression_model.sample_rate == self.cfg.sample_rate, (
            f"Compression model sample rate is {self.compression_model.sample_rate} but "
            f"Solver sample rate is {self.cfg.sample_rate}."
            )
        # ensure we have matching configuration between LM and compression model
        assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
            "Cardinalities of the LM and compression model don't match: ",
            f"LM cardinality is {self.cfg.transformer_lm.card} vs ",
            f"compression model cardinality is {self.compression_model.cardinality}"
        )
        assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
            "Numbers of codebooks of the LM and compression models don't match: ",
            f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ",
            f"compression model numer of codebooks is {self.compression_model.num_codebooks}"
        )
        self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
                         self.compression_model.num_codebooks, self.compression_model.cardinality,
                         self.compression_model.frame_rate)
        # instantiate LM model
        self.model: tp.Union[models.LMModel, models.FlowMatchingModel] = models.builders.get_lm_model(
            self.cfg).to(self.device)

        # initialize optimization
        self.initialize_optimization()

    def build_dataloaders(self) -> None:
        """Instantiate audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)

    def show(self) -> None:
        """Show the compression model and LM model."""
        self.logger.info("Compression model:")
        self.log_model_summary(self.compression_model)
        self.logger.info("LM model:")
        self.log_model_summary(self.model)

    def load_state_dict(self, state: dict) -> None:
        if 'condition_provider' in state:
            model_state = state['model']
            condition_provider_state = state.pop('condition_provider')
            prefix = 'condition_provider.'
            for key, value in condition_provider_state.items():
                key = prefix + key
                assert key not in model_state
                model_state[key] = value
        if 'compression_model' in state:
            # We used to store the `compression_model` state in the checkpoint, however
            # this is in general not needed, as the compression model should always be readable
            # from the original `cfg.compression_model_checkpoint` location.
            compression_model_state = state.pop('compression_model')
            before_hash = model_hash(self.compression_model)
            self.compression_model.load_state_dict(compression_model_state)
            after_hash = model_hash(self.compression_model)
            if before_hash != after_hash:
                raise RuntimeError(
                    "The compression model state inside the checkpoint is different"
                    " from the one obtained from compression_model_checkpoint..."
                    "We do not support altering the compression model inside the LM "
                    "checkpoint as parts of the code, in particular for running eval post-training "
                    "will use the compression_model_checkpoint as the source of truth.")

        super().load_state_dict(state)

    def load_from_pretrained(self, name: str):
        # TODO: support native HF versions of MusicGen.
        lm_pkg = models.loaders.load_lm_model_ckpt(name)
        state: dict = {
            'best_state': {
                'model': lm_pkg['best_state'],
            },
        }
        return state

    def _compute_cross_entropy(
        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Compute cross entropy between multi-codebook targets and model's logits.
        The cross entropy is computed per codebook to provide codebook-level cross entropy.
        Valid timesteps for each of the codebooks are pulled from the mask, where invalid
        timesteps are set to 0.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
        Returns:
            ce (torch.Tensor): Cross entropy averaged over the codebooks
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        B, K, T = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())
        # average cross entropy across codebooks
        ce = ce / K
        return ce, ce_per_codebook

    def _get_audio_tokens(self, audio: torch.Tensor):
        with torch.no_grad():
            audio_tokens, scale = self.compression_model.encode(audio)
            assert scale is None, "Scaled compression model not supported with LM."
            return audio_tokens

    def _prepare_tokens_and_attributes(
        self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        check_synchronization_points: bool = False
    ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]:
        """Prepare input batchs for language model training.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
                and corresponding metadata as SegmentWithAttributes (with B items).
            check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
        Returns:
            Condition tensors (dict[str, any]): Preprocessed condition attributes.
            Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s],
                with B the batch size, K the number of codebooks, T_s the token timesteps.
            Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
        """
        if self.model.training:
            warnings.warn(
                "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. "
                "This is inconsistent with how model were trained in the MusicGen paper. We removed the "
                "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
                "Really sorry about that.")
        if self._cached_batch_loader is None or self.current_stage != "train":
            audio, infos = batch
            audio = audio.to(self.device)
            audio_tokens = None
            assert audio.size(0) == len(infos), (
                f"Mismatch between number of items in audio batch ({audio.size(0)})",
                f" and in metadata ({len(infos)})"
            )
        else:
            audio = None
            # In that case the batch will be a tuple coming from the _cached_batch_writer bit below.
            infos, = batch  # type: ignore
            assert all([isinstance(info, AudioInfo) for info in infos])
            assert all([info.audio_tokens is not None for info in infos])  # type: ignore
            audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device)  # type: ignore
            audio_tokens = audio_tokens.long()
            for info in infos:
                if isinstance(info, MusicInfo):
                    # Careful here, if you want to use this condition_wav (e.g. chroma conditioning),
                    # then you must be using the chroma cache! otherwise the code will try
                    # to use this segment and fail (by that I mean you will see NaN everywhere).
                    info.self_wav = WavCondition(
                        torch.full([1, info.channels, info.total_frames], float('NaN')),
                        length=torch.tensor([info.n_frames]),
                        sample_rate=[info.sample_rate],
                        path=[info.meta.path],
                        seek_time=[info.seek_time])
                    dataset = get_dataset_from_loader(self.dataloaders['original_train'])
                    assert isinstance(dataset, MusicDataset), type(dataset)
                    if dataset.paraphraser is not None and info.description is not None:
                        # Hackily reapplying the paraphraser when using the cache.
                        info.description = dataset.paraphraser.sample_paraphrase(
                            info.meta.path, info.description)
        # prepare attributes
        attributes = [info.to_condition_attributes() for info in infos]
        attributes = self.model.cfg_dropout(attributes)
        attributes = self.model.att_dropout(attributes)
        tokenized = self.model.condition_provider.tokenize(attributes)

        # Now we should be synchronization free.
        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")

        if audio_tokens is None:
            audio_tokens = self._get_audio_tokens(audio)

        with self.autocast:
            condition_tensors = self.model.condition_provider(tokenized)

        # create a padding mask to hold valid vs invalid positions
        padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
        # replace encodec tokens from padded audio with special_token_id
        if self.cfg.tokens.padding_with_special_token:
            audio_tokens = audio_tokens.clone()
            padding_mask = padding_mask.clone()
            token_sample_rate = self.compression_model.frame_rate
            B, K, T_s = audio_tokens.shape
            for i in range(B):
                n_samples = infos[i].n_frames
                audio_sample_rate = infos[i].sample_rate
                # take the last token generated from actual audio frames (non-padded audio)
                valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate)
                audio_tokens[i, :, valid_tokens:] = self.model.special_token_id
                padding_mask[i, :, valid_tokens:] = 0

        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("default")

        if self._cached_batch_writer is not None and self.current_stage == 'train':
            assert self._cached_batch_loader is None
            assert audio_tokens is not None
            for info, one_audio_tokens in zip(infos, audio_tokens):
                assert isinstance(info, AudioInfo)
                if isinstance(info, MusicInfo):
                    assert not info.joint_embed, "joint_embed and cache not supported yet."
                    info.self_wav = None
                assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item()
                info.audio_tokens = one_audio_tokens.short().cpu()
            self._cached_batch_writer.save(infos)

        return condition_tensors, audio_tokens, padding_mask

    def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
        """Perform one training or valid step on a given batch."""
        check_synchronization_points = idx == 1 and self.device == 'cuda'

        condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
            batch, check_synchronization_points)

        self.deadlock_detect.update('tokens_and_conditions')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('warn')

        with self.autocast:
            style_mask = None
            if hasattr(self.model.condition_provider.conditioners, 'self_wav'):
                if isinstance(self.model.condition_provider.conditioners.self_wav, StyleConditioner):
                    style_mask = self.model.condition_provider.conditioners.self_wav.mask

            model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors)  # type: ignore
            logits = model_output.logits
            if style_mask is not None:
                mask = padding_mask & model_output.mask & style_mask
            else:
                mask = padding_mask & model_output.mask
            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
            loss = ce
        self.deadlock_detect.update('loss')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('default')

        if self.is_training:
            metrics['lr'] = self.optimizer.param_groups[0]['lr']
            if self.scaler is not None:
                loss = self.scaler.scale(loss)
            self.deadlock_detect.update('scale')
            if self.cfg.fsdp.use:
                loss.backward()
                flashy.distrib.average_tensors(self.model.buffers())
            elif self.cfg.optim.eager_sync:
                with flashy.distrib.eager_sync_model(self.model):
                    loss.backward()
            else:
                # this should always be slower but can be useful
                # for weird use cases like multiple backwards.
                loss.backward()
                flashy.distrib.sync_model(self.model)
            self.deadlock_detect.update('backward')

            if self.scaler is not None:
                self.scaler.unscale_(self.optimizer)
            if self.cfg.optim.max_norm:
                if self.cfg.fsdp.use:
                    metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
                else:
                    metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.cfg.optim.max_norm
                    )
            if self.scaler is None:
                self.optimizer.step()
            else:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.deadlock_detect.update('optim')
            if self.scaler is not None:
                scale = self.scaler.get_scale()
                metrics['grad_scale'] = scale
            if not loss.isfinite().all():
                raise RuntimeError("Model probably diverged.")

        metrics['ce'] = ce
        metrics['ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'ce_q{k + 1}'] = ce_q
            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)

        return metrics

    @torch.no_grad()
    def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
                          gen_duration: float, prompt_duration: tp.Optional[float] = None,
                          remove_text_conditioning: bool = False,
                          ) -> dict:
        """Run generate step on a batch of optional audio tensor and corresponding attributes.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Batch of audio tensor and metadata.
            gen_duration (float): Target audio duration for the generation.
            prompt_duration (float, optional): Duration of the audio prompt to use for continuation.
                If None, no audio prompt is used.
            remove_text_conditioning (bool): Whether to drop the text (description) conditioning
                before generation.
        Returns:
            gen_outputs (dict): Generation outputs, consisting of audio and audio tokens from both the
                generation and the prompt, along with additional information.
        """
        bench_start = time.time()
        audio, meta = batch
        assert audio.size(0) == len(meta), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})",
            f" and in metadata ({len(meta)})"
        )
        # prepare attributes
        attributes = [x.to_condition_attributes() for x in meta]
        if remove_text_conditioning:
            attributes = _drop_description_condition(attributes)

        # prepare audio prompt
        if prompt_duration is None:
            prompt_audio = None
        else:
            assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
            prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
            prompt_audio = audio[..., :prompt_audio_frames]

        # get audio tokens from compression model
        if prompt_audio is None or prompt_audio.nelement() == 0:
            num_samples = len(attributes)
            prompt_tokens = None
        else:
            num_samples = None
            prompt_audio = prompt_audio.to(self.device)
            prompt_tokens, scale = self.compression_model.encode(prompt_audio)
            assert scale is None, "Compression model in MusicGen should not require rescaling."

        # generate by sampling from the LM
        with self.autocast:
            total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
            gen_tokens = self.model.generate(
                prompt_tokens, attributes, max_gen_len=total_gen_len,
                num_samples=num_samples, **self.generation_params)

        # generate audio from tokens
        assert gen_tokens.dim() == 3
        gen_audio = self.compression_model.decode(gen_tokens, None)

        bench_end = time.time()
        gen_outputs = {
            'rtf': (bench_end - bench_start) / gen_duration,
            'ref_audio': audio,
            'gen_audio': gen_audio,
            'gen_tokens': gen_tokens,
            'prompt_audio': prompt_audio,
            'prompt_tokens': prompt_tokens,
        }
        return gen_outputs

    def generate_audio(self) -> dict:
        """Audio generation stage."""
        generate_stage_name = f'{self.current_stage}'
        sample_manager = SampleManager(self.xp)
        self.logger.info(f"Generating samples in {sample_manager.base_folder}")
        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        dataset = get_dataset_from_loader(loader)
        dataset_duration = dataset.segment_duration
        assert dataset_duration is not None
        assert isinstance(dataset, AudioDataset)
        target_duration = self.cfg.generate.lm.gen_duration
        prompt_duration = self.cfg.generate.lm.prompt_duration
        if target_duration is None:
            target_duration = dataset_duration
        if prompt_duration is None:
            prompt_duration = dataset_duration / 4
        assert prompt_duration < dataset_duration, (
            f"Specified prompt duration ({prompt_duration}s) is longer",
            f" than reference audio duration ({dataset_duration}s)"
        )

        def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
            hydrated_conditions = []
            for sample in [x.to_condition_attributes() for x in meta]:
                cond_dict = {}
                for cond_type in sample.__annotations__.keys():
                    for cond_key, cond_val in getattr(sample, cond_type).items():
                        if cond_key not in self.model.condition_provider.conditioners.keys():
                            continue
                        if is_jsonable(cond_val):
                            cond_dict[cond_key] = cond_val
                        elif isinstance(cond_val, WavCondition):
                            cond_dict[cond_key] = cond_val.path
                        elif isinstance(cond_val, JointEmbedCondition):
                            cond_dict[cond_key] = cond_val.text  # only support text at inference for now
                        else:
                            # if we reached this point, it is not clear how to log the condition
                            # so we just log the type.
                            cond_dict[cond_key] = str(type(cond_val))
                            continue
                hydrated_conditions.append(cond_dict)
            return hydrated_conditions

        metrics: dict = {}
        average = flashy.averager()
        for batch in lp:
            audio, meta = batch
            # metadata for sample manager
            hydrated_conditions = get_hydrated_conditions(meta)
            sample_generation_params = {
                **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
                **self.generation_params
            }
            if self.cfg.generate.lm.unprompted_samples:
                if self.cfg.generate.lm.gen_gt_samples:
                    # get the ground truth instead of generation
                    self.logger.warn(
                        "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
                    gen_unprompted_audio = audio
                    rtf = 1.
                else:
                    gen_unprompted_outputs = self.run_generate_step(
                        batch, gen_duration=target_duration, prompt_duration=None)
                    gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
                    rtf = gen_unprompted_outputs['rtf']
                sample_manager.add_samples(
                    gen_unprompted_audio, self.epoch, hydrated_conditions,
                    ground_truth_wavs=audio, generation_args=sample_generation_params)

            if self.cfg.generate.lm.prompted_samples:
                gen_outputs = self.run_generate_step(
                    batch, gen_duration=target_duration, prompt_duration=prompt_duration)
                gen_audio = gen_outputs['gen_audio'].cpu()
                prompt_audio = gen_outputs['prompt_audio'].cpu()
                sample_manager.add_samples(
                    gen_audio, self.epoch, hydrated_conditions,
                    prompt_wavs=prompt_audio, ground_truth_wavs=audio,
                    generation_args=sample_generation_params)
            if self.cfg.generate.lm.no_text_conditioning:
                gen_outputs = self.run_generate_step(
                    batch, gen_duration=target_duration, prompt_duration=None,
                    remove_text_conditioning=self.cfg.generate.lm.no_text_conditioning)
                gen_audio = gen_outputs['gen_audio'].cpu()
                rtf = gen_outputs['rtf']
                # Here, the prompt is the original audio provided for the style conditioning
                prompt_audio = gen_outputs['ref_audio'].cpu()
                sample_manager.add_samples(
                    gen_audio, self.epoch, hydrated_conditions,
                    prompt_wavs=prompt_audio, ground_truth_wavs=audio,
                    generation_args=sample_generation_params)

            metrics['rtf'] = rtf
            metrics = average(metrics)

        flashy.distrib.barrier()
        return metrics

    def generate(self) -> dict:
        """Generate stage."""
        self.model.eval()
        with torch.no_grad():
            return self.generate_audio()

    def run_epoch(self):
        if self.cfg.cache.write:
            if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard:
                return
        super().run_epoch()

    def train(self):
        """Train stage.
        """
        if self._cached_batch_writer is not None:
            self._cached_batch_writer.start_epoch(self.epoch)
        if self._cached_batch_loader is None:
            dataset = get_dataset_from_loader(self.dataloaders['train'])
            assert isinstance(dataset, AudioDataset)
            dataset.current_epoch = self.epoch
        else:
            self._cached_batch_loader.start_epoch(self.epoch)
        return super().train()

    def evaluate_audio_generation(self) -> dict:
        """Evaluate audio generation with off-the-shelf metrics."""
        evaluate_stage_name = f'{self.current_stage}_generation'
        # instantiate evaluation metrics; if at least one metric is defined, run audio generation evaluation
        fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
        kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
        text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
        chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
        should_run_eval = False
        eval_chroma_wavs: tp.Optional[torch.Tensor] = None
        if self.cfg.evaluate.metrics.fad:
            fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.kld:
            kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.text_consistency:
            text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.chroma_cosine:
            chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
            # if we have predefined wavs for chroma, we should purge them when computing the cosine metric
            has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
                                          self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
            if has_predefined_eval_chromas:
                warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
                                       'Resetting eval chromas to None for evaluation.')
                eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs  # type: ignore
                self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None)  # type: ignore
            should_run_eval = True

        def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
            audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
            compressed_audio = self.compression_model.decode(audio_tokens, scale)
            return compressed_audio[..., :audio.shape[-1]]

        metrics: dict = {}
        if should_run_eval:
            loader = self.dataloaders['evaluate']
            updates = len(loader)
            lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
            average = flashy.averager()
            dataset = get_dataset_from_loader(loader)
            assert isinstance(dataset, AudioDataset)
            self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")

            for idx, batch in enumerate(lp):
                audio, meta = batch
                assert all([self.cfg.sample_rate == m.sample_rate for m in meta])

                target_duration = audio.shape[-1] / self.cfg.sample_rate
                if self.cfg.evaluate.fixed_generation_duration:
                    target_duration = self.cfg.evaluate.fixed_generation_duration

                gen_outputs = self.run_generate_step(
                    batch, gen_duration=target_duration,
                    remove_text_conditioning=self.cfg.evaluate.get('remove_text_conditioning', False)
                )
                y_pred = gen_outputs['gen_audio'].detach()
                y_pred = y_pred[..., :audio.shape[-1]]

                normalize_kwargs = dict(self.cfg.generate.audio)
                normalize_kwargs.pop('format', None)
                y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
                y = audio.cpu()  # should already be on CPU but just in case
                sizes = torch.tensor([m.n_frames for m in meta])  # actual sizes without padding
                sample_rates = torch.tensor([m.sample_rate for m in meta])  # sample rates for audio samples
                audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]

                if fad is not None:
                    if self.cfg.metrics.fad.use_gt:
                        y_pred = get_compressed_audio(y).cpu()
                    fad.update(y_pred, y, sizes, sample_rates, audio_stems)
                if kldiv is not None:
                    if self.cfg.metrics.kld.use_gt:
                        y_pred = get_compressed_audio(y).cpu()
                    kldiv.update(y_pred, y, sizes, sample_rates)
                if text_consistency is not None:
                    texts = [m.description for m in meta]
                    if self.cfg.metrics.text_consistency.use_gt:
                        y_pred = y
                    text_consistency.update(y_pred, texts, sizes, sample_rates)
                if chroma_cosine is not None:
                    if self.cfg.metrics.chroma_cosine.use_gt:
                        y_pred = get_compressed_audio(y).cpu()
                    chroma_cosine.update(y_pred, y, sizes, sample_rates)
                    # restore chroma conditioner's eval chroma wavs
                    if eval_chroma_wavs is not None:
                        self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)

            flashy.distrib.barrier()
            if fad is not None:
                metrics['fad'] = fad.compute()
            if kldiv is not None:
                kld_metrics = kldiv.compute()
                metrics.update(kld_metrics)
            if text_consistency is not None:
                metrics['text_consistency'] = text_consistency.compute()
            if chroma_cosine is not None:
                metrics['chroma_cosine'] = chroma_cosine.compute()
            metrics = average(metrics)
            metrics = flashy.distrib.average_metrics(metrics, len(loader))

        return metrics

    def evaluate(self) -> dict:
        """Evaluate stage."""
        self.model.eval()
        with torch.no_grad():
            metrics: dict = {}
            if self.cfg.evaluate.metrics.base:
                metrics.update(self.common_train_valid('evaluate'))
            gen_metrics = self.evaluate_audio_generation()
            return {**metrics, **gen_metrics}

Solver for MusicGen training task.

Used in: https://arxiv.org/abs/2306.05284

Ancestors

audiocraft.solvers.base.StandardSolver

Subclasses

Class variables

var DATASET_TYPE : DatasetType
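
The dataset type drives which dataset builder is used in build_dataloaders. A minimal sketch, assuming a hypothetical subclass name (SoundGenSolver) and that builders.DatasetType also exposes a SOUND member, of how a solver variant could target a different dataset type:

from audiocraft.solvers import builders
from audiocraft.solvers.musicgen import MusicGenSolver

class SoundGenSolver(MusicGenSolver):
    """Hypothetical variant trained on generic sound data rather than music."""
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND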

Static methods

def get_eval_solver_from_sig(sig: str,
                             dtype: str | None = None,
                             device: str | None = None,
                             autocast: bool = True,
                             batch_size: int | None = None,
                             override_cfg: dict | omegaconf.dictconfig.DictConfig | None = None,
                             **kwargs)
@staticmethod
def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                             device: tp.Optional[str] = None, autocast: bool = True,
                             batch_size: tp.Optional[int] = None,
                             override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                             **kwargs):
    """Mostly a convenience function around magma.train.get_solver_from_sig,
    populating all the proper param, deactivating EMA, FSDP, loading the best state,
    basically all you need to get a solver ready to "play" with in single GPU mode
    and with minimal memory overhead.

    Args:
        sig (str): signature to load.
        dtype (str or None): potential dtype, as a string, i.e. 'float16'.
        device (str or None): potential device, as a string, i.e. 'cuda'.
        override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'.
    """
    from audiocraft import train
    our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
    our_override_cfg['autocast'] = autocast
    if dtype is not None:
        our_override_cfg['dtype'] = dtype
    if device is not None:
        our_override_cfg['device'] = device
    if batch_size is not None:
        our_override_cfg['dataset'] = {'batch_size': batch_size}
    if override_cfg is None:
        override_cfg = {}
    override_cfg = omegaconf.OmegaConf.merge(
        omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
    solver = train.get_solver_from_sig(
        sig, override_cfg=override_cfg,
        load_best=True, disable_fsdp=True,
        ignore_state_keys=['optimizer', 'ema'], **kwargs)
    solver.model.eval()
    return solver

Mostly a convenience function around audiocraft.train.get_solver_from_sig, populating all the proper parameters, deactivating EMA and FSDP, and loading the best state: basically all you need to get a solver ready to "play" with in single GPU mode and with minimal memory overhead.

Args

sig : str
signature to load.
dtype : str or None
potential dtype, as a string, e.g. 'float16'.
device : str or None
potential device, as a string, e.g. 'cuda'.
autocast : bool
whether to keep autocast enabled.
batch_size : int or None
optional batch size to use for the dataset.
override_cfg : dict or omegaconf.DictConfig or None
additional config overrides, merged with the overrides set by this function.
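
A hedged usage sketch: loading a trained solver for single-GPU evaluation. The signature string below is a placeholder, not a real experiment signature.

from audiocraft.solvers.musicgen import MusicGenSolver

# '0123abcd' is a placeholder XP signature; substitute a real one.
solver = MusicGenSolver.get_eval_solver_from_sig(
    '0123abcd', dtype='float16', device='cuda', batch_size=8)
# The returned solver has EMA and FSDP disabled, the best state loaded,
# and its LM already set to eval mode.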

Methods

def build_dataloaders(self) ‑> None
def build_dataloaders(self) -> None:
    """Instantiate audio dataloaders for each stage."""
    self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)

Instantiate audio dataloaders for each stage.
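
The resulting self.dataloaders is a dict keyed by stage name; a short sketch, assuming solver is an instance obtained as in the sketch above:

# Each stage has its own loader yielding (audio, metadata) batches.
generate_loader = solver.dataloaders['generate']
audio, metas = next(iter(generate_loader))  # audio: [B, C, T], metas: list of SegmentWithAttributes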

def build_model(self) ‑> None
def build_model(self) -> None:
    """Instantiate models and optimizer."""
    # we can potentially not use all quantizers with which the EnCodec model was trained
    # (e.g. we trained the model with quantizers dropout)
    self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
        self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
    assert self.compression_model.sample_rate == self.cfg.sample_rate, (
        f"Compression model sample rate is {self.compression_model.sample_rate} but "
        f"Solver sample rate is {self.cfg.sample_rate}."
        )
    # ensure we have matching configuration between LM and compression model
    assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
        "Cardinalities of the LM and compression model don't match: ",
        f"LM cardinality is {self.cfg.transformer_lm.card} vs ",
        f"compression model cardinality is {self.compression_model.cardinality}"
    )
    assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
        "Numbers of codebooks of the LM and compression models don't match: ",
        f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ",
        f"compression model numer of codebooks is {self.compression_model.num_codebooks}"
    )
    self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
                     self.compression_model.num_codebooks, self.compression_model.cardinality,
                     self.compression_model.frame_rate)
    # instantiate LM model
    self.model: tp.Union[models.LMModel, models.FlowMatchingModel] = models.builders.get_lm_model(
        self.cfg).to(self.device)

    # initialize optimization
    self.initialize_optimization()

Instantiate models and optimizer.
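
A sketch of the config fields build_model relies on for the LM / compression-model consistency checks (values shown are placeholders, not defaults):

# Illustrative config fragment; keys follow the asserts in build_model above.
cfg_fragment = {
    'compression_model_checkpoint': '//pretrained/facebook/encodec_32khz',  # assumed checkpoint reference
    'sample_rate': 32000,                 # must match compression_model.sample_rate
    'transformer_lm': {
        'card': 2048,                     # must match compression_model.cardinality
        'n_q': 4,                         # must match compression_model.num_codebooks
    },
}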

def evaluate_audio_generation(self) ‑> dict
def evaluate_audio_generation(self) -> dict:
    """Evaluate audio generation with off-the-shelf metrics."""
    evaluate_stage_name = f'{self.current_stage}_generation'
    # instantiate evaluation metrics; if at least one metric is defined, run audio generation evaluation
    fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
    kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
    text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
    chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
    should_run_eval = False
    eval_chroma_wavs: tp.Optional[torch.Tensor] = None
    if self.cfg.evaluate.metrics.fad:
        fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
        should_run_eval = True
    if self.cfg.evaluate.metrics.kld:
        kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
        should_run_eval = True
    if self.cfg.evaluate.metrics.text_consistency:
        text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
        should_run_eval = True
    if self.cfg.evaluate.metrics.chroma_cosine:
        chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
        # if we have predefined wavs for chroma, we should purge them when computing the cosine metric
        has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
                                      self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
        if has_predefined_eval_chromas:
            warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
                                   'Resetting eval chromas to None for evaluation.')
            eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs  # type: ignore
            self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None)  # type: ignore
        should_run_eval = True

    def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
        audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
        compressed_audio = self.compression_model.decode(audio_tokens, scale)
        return compressed_audio[..., :audio.shape[-1]]

    metrics: dict = {}
    if should_run_eval:
        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
        average = flashy.averager()
        dataset = get_dataset_from_loader(loader)
        assert isinstance(dataset, AudioDataset)
        self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")

        for idx, batch in enumerate(lp):
            audio, meta = batch
            assert all([self.cfg.sample_rate == m.sample_rate for m in meta])

            target_duration = audio.shape[-1] / self.cfg.sample_rate
            if self.cfg.evaluate.fixed_generation_duration:
                target_duration = self.cfg.evaluate.fixed_generation_duration

            gen_outputs = self.run_generate_step(
                batch, gen_duration=target_duration,
                remove_text_conditioning=self.cfg.evaluate.get('remove_text_conditioning', False)
            )
            y_pred = gen_outputs['gen_audio'].detach()
            y_pred = y_pred[..., :audio.shape[-1]]

            normalize_kwargs = dict(self.cfg.generate.audio)
            normalize_kwargs.pop('format', None)
            y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
            y = audio.cpu()  # should already be on CPU but just in case
            sizes = torch.tensor([m.n_frames for m in meta])  # actual sizes without padding
            sample_rates = torch.tensor([m.sample_rate for m in meta])  # sample rates for audio samples
            audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]

            if fad is not None:
                if self.cfg.metrics.fad.use_gt:
                    y_pred = get_compressed_audio(y).cpu()
                fad.update(y_pred, y, sizes, sample_rates, audio_stems)
            if kldiv is not None:
                if self.cfg.metrics.kld.use_gt:
                    y_pred = get_compressed_audio(y).cpu()
                kldiv.update(y_pred, y, sizes, sample_rates)
            if text_consistency is not None:
                texts = [m.description for m in meta]
                if self.cfg.metrics.text_consistency.use_gt:
                    y_pred = y
                text_consistency.update(y_pred, texts, sizes, sample_rates)
            if chroma_cosine is not None:
                if self.cfg.metrics.chroma_cosine.use_gt:
                    y_pred = get_compressed_audio(y).cpu()
                chroma_cosine.update(y_pred, y, sizes, sample_rates)
                # restore chroma conditioner's eval chroma wavs
                if eval_chroma_wavs is not None:
                    self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)

        flashy.distrib.barrier()
        if fad is not None:
            metrics['fad'] = fad.compute()
        if kldiv is not None:
            kld_metrics = kldiv.compute()
            metrics.update(kld_metrics)
        if text_consistency is not None:
            metrics['text_consistency'] = text_consistency.compute()
        if chroma_cosine is not None:
            metrics['chroma_cosine'] = chroma_cosine.compute()
        metrics = average(metrics)
        metrics = flashy.distrib.average_metrics(metrics, len(loader))

    return metrics

Evaluate audio generation with off-the-shelf metrics.
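
Each metric is only instantiated when the corresponding flag under evaluate.metrics is set. A hedged override sketch enabling two of them (the exact metric sub-configs under metrics depend on the experiment config files):

import omegaconf

# Enable FAD and KL-divergence evaluation only.
eval_overrides = omegaconf.OmegaConf.create({
    'evaluate': {'metrics': {'fad': True, 'kld': True,
                             'text_consistency': False, 'chroma_cosine': False}},
})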

def generate_audio(self) ‑> dict
def generate_audio(self) -> dict:
    """Audio generation stage."""
    generate_stage_name = f'{self.current_stage}'
    sample_manager = SampleManager(self.xp)
    self.logger.info(f"Generating samples in {sample_manager.base_folder}")
    loader = self.dataloaders['generate']
    updates = len(loader)
    lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

    dataset = get_dataset_from_loader(loader)
    dataset_duration = dataset.segment_duration
    assert dataset_duration is not None
    assert isinstance(dataset, AudioDataset)
    target_duration = self.cfg.generate.lm.gen_duration
    prompt_duration = self.cfg.generate.lm.prompt_duration
    if target_duration is None:
        target_duration = dataset_duration
    if prompt_duration is None:
        prompt_duration = dataset_duration / 4
    assert prompt_duration < dataset_duration, (
        f"Specified prompt duration ({prompt_duration}s) is longer",
        f" than reference audio duration ({dataset_duration}s)"
    )

    def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
        hydrated_conditions = []
        for sample in [x.to_condition_attributes() for x in meta]:
            cond_dict = {}
            for cond_type in sample.__annotations__.keys():
                for cond_key, cond_val in getattr(sample, cond_type).items():
                    if cond_key not in self.model.condition_provider.conditioners.keys():
                        continue
                    if is_jsonable(cond_val):
                        cond_dict[cond_key] = cond_val
                    elif isinstance(cond_val, WavCondition):
                        cond_dict[cond_key] = cond_val.path
                    elif isinstance(cond_val, JointEmbedCondition):
                        cond_dict[cond_key] = cond_val.text  # only support text at inference for now
                    else:
                        # if we reached this point, it is not clear how to log the condition
                        # so we just log the type.
                        cond_dict[cond_key] = str(type(cond_val))
                        continue
            hydrated_conditions.append(cond_dict)
        return hydrated_conditions

    metrics: dict = {}
    average = flashy.averager()
    for batch in lp:
        audio, meta = batch
        # metadata for sample manager
        hydrated_conditions = get_hydrated_conditions(meta)
        sample_generation_params = {
            **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
            **self.generation_params
        }
        if self.cfg.generate.lm.unprompted_samples:
            if self.cfg.generate.lm.gen_gt_samples:
                # get the ground truth instead of generation
                self.logger.warn(
                    "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
                gen_unprompted_audio = audio
                rtf = 1.
            else:
                gen_unprompted_outputs = self.run_generate_step(
                    batch, gen_duration=target_duration, prompt_duration=None)
                gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
                rtf = gen_unprompted_outputs['rtf']
            sample_manager.add_samples(
                gen_unprompted_audio, self.epoch, hydrated_conditions,
                ground_truth_wavs=audio, generation_args=sample_generation_params)

        if self.cfg.generate.lm.prompted_samples:
            gen_outputs = self.run_generate_step(
                batch, gen_duration=target_duration, prompt_duration=prompt_duration)
            gen_audio = gen_outputs['gen_audio'].cpu()
            prompt_audio = gen_outputs['prompt_audio'].cpu()
            sample_manager.add_samples(
                gen_audio, self.epoch, hydrated_conditions,
                prompt_wavs=prompt_audio, ground_truth_wavs=audio,
                generation_args=sample_generation_params)
        if self.cfg.generate.lm.no_text_conditioning:
            gen_outputs = self.run_generate_step(
                batch, gen_duration=target_duration, prompt_duration=None,
                remove_text_conditioning=self.cfg.generate.lm.no_text_conditioning)
            gen_audio = gen_outputs['gen_audio'].cpu()
            rtf = gen_outputs['rtf']
            # Here, the prompt is the original audio provided for the style conditioning
            prompt_audio = gen_outputs['ref_audio'].cpu()
            sample_manager.add_samples(
                gen_audio, self.epoch, hydrated_conditions,
                prompt_wavs=prompt_audio, ground_truth_wavs=audio,
                generation_args=sample_generation_params)

        metrics['rtf'] = rtf
        metrics = average(metrics)

    flashy.distrib.barrier()
    return metrics

Audio generation stage.
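
The stage is driven by the generate.lm section of the config. A sketch of the keys this method reads (values are placeholders):

generate_lm_cfg = {
    'gen_duration': 10.0,           # target duration; falls back to the dataset segment duration if None
    'prompt_duration': 2.5,         # prompt length for continuations; defaults to a quarter of the segment if None
    'unprompted_samples': True,     # generate samples without an audio prompt
    'prompted_samples': True,       # generate continuations from an audio prompt
    'gen_gt_samples': False,        # if True, log ground truth audio instead of generating
    'no_text_conditioning': False,  # if True, also generate with the text condition dropped
}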

def get_formatter(self, stage_name: str) ‑> flashy.formatter.Formatter
def get_formatter(self, stage_name: str) -> flashy.Formatter:
    return flashy.Formatter({
        'lr': '.2E',
        'ce': '.3f',
        'ppl': '.3f',
        'grad_norm': '.3E',
    }, exclude_keys=['ce_q*', 'ppl_q*'])
def initialize_optimization(self) ‑> None
def initialize_optimization(self) -> None:
    if self.cfg.fsdp.use:
        assert not self.cfg.autocast, "Cannot use autocast with fsdp"
        self.model = self.wrap_with_fsdp(self.model)
    self.register_ema('model')
    # initialize optimization
    self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
    self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
    self.register_stateful('model', 'optimizer', 'lr_scheduler')
    self.register_best_state('model')
    self.autocast_dtype = {
        'float16': torch.float16, 'bfloat16': torch.bfloat16
    }[self.cfg.autocast_dtype]
    self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
    if self.cfg.fsdp.use:
        need_scaler = self.cfg.fsdp.param_dtype == 'float16'
    else:
        need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
    if need_scaler:
        if self.cfg.fsdp.use:
            from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
            self.scaler = ShardedGradScaler()  # type: ignore
        else:
            self.scaler = torch.cuda.amp.GradScaler()
        self.register_stateful('scaler')
def load_from_pretrained(self, name: str)
def load_from_pretrained(self, name: str):
    # TODO: support native HF versions of MusicGen.
    lm_pkg = models.loaders.load_lm_model_ckpt(name)
    state: dict = {
        'best_state': {
            'model': lm_pkg['best_state'],
        },
    }
    return state
def load_state_dict(self, state: dict) ‑> None
def load_state_dict(self, state: dict) -> None:
    if 'condition_provider' in state:
        model_state = state['model']
        condition_provider_state = state.pop('condition_provider')
        prefix = 'condition_provider.'
        for key, value in condition_provider_state.items():
            key = prefix + key
            assert key not in model_state
            model_state[key] = value
    if 'compression_model' in state:
        # We used to store the `compression_model` state in the checkpoint, however
        # this is in general not needed, as the compression model should always be readable
        # from the original `cfg.compression_model_checkpoint` location.
        compression_model_state = state.pop('compression_model')
        before_hash = model_hash(self.compression_model)
        self.compression_model.load_state_dict(compression_model_state)
        after_hash = model_hash(self.compression_model)
        if before_hash != after_hash:
            raise RuntimeError(
                "The compression model state inside the checkpoint is different"
                " from the one obtained from compression_model_checkpoint..."
                "We do not support altering the compression model inside the LM "
                "checkpoint as parts of the code, in particular for running eval post-training "
                "will use the compression_model_checkpoint as the source of truth.")

    super().load_state_dict(state)
def run_generate_step(self,
                      batch: Tuple[torch.Tensor, List[SegmentWithAttributes]],
                      gen_duration: float,
                      prompt_duration: float | None = None,
                      remove_text_conditioning: bool = False) ‑> dict
@torch.no_grad()
def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
                      gen_duration: float, prompt_duration: tp.Optional[float] = None,
                      remove_text_conditioning: bool = False,
                      ) -> dict:
    """Run generate step on a batch of optional audio tensor and corresponding attributes.

    Args:
        batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Batch of audio tensor and metadata.
        gen_duration (float): Target audio duration for the generation.
        prompt_duration (float, optional): Duration of the audio prompt to use for continuation.
            If None, no audio prompt is used.
        remove_text_conditioning (bool): Whether to drop the text (description) conditioning
            before generation.
    Returns:
        gen_outputs (dict): Generation outputs, consisting of audio and audio tokens from both the
            generation and the prompt, along with additional information.
    """
    bench_start = time.time()
    audio, meta = batch
    assert audio.size(0) == len(meta), (
        f"Mismatch between number of items in audio batch ({audio.size(0)})",
        f" and in metadata ({len(meta)})"
    )
    # prepare attributes
    attributes = [x.to_condition_attributes() for x in meta]
    if remove_text_conditioning:
        attributes = _drop_description_condition(attributes)

    # prepare audio prompt
    if prompt_duration is None:
        prompt_audio = None
    else:
        assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
        prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
        prompt_audio = audio[..., :prompt_audio_frames]

    # get audio tokens from compression model
    if prompt_audio is None or prompt_audio.nelement() == 0:
        num_samples = len(attributes)
        prompt_tokens = None
    else:
        num_samples = None
        prompt_audio = prompt_audio.to(self.device)
        prompt_tokens, scale = self.compression_model.encode(prompt_audio)
        assert scale is None, "Compression model in MusicGen should not require rescaling."

    # generate by sampling from the LM
    with self.autocast:
        total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
        gen_tokens = self.model.generate(
            prompt_tokens, attributes, max_gen_len=total_gen_len,
            num_samples=num_samples, **self.generation_params)

    # generate audio from tokens
    assert gen_tokens.dim() == 3
    gen_audio = self.compression_model.decode(gen_tokens, None)

    bench_end = time.time()
    gen_outputs = {
        'rtf': (bench_end - bench_start) / gen_duration,
        'ref_audio': audio,
        'gen_audio': gen_audio,
        'gen_tokens': gen_tokens,
        'prompt_audio': prompt_audio,
        'prompt_tokens': prompt_tokens,
    }
    return gen_outputs

Run generate step on a batch of optional audio tensor and corresponding attributes.

Args

batch : tuple[torch.Tensor, list[SegmentWithAttributes]]
Batch of audio tensor and corresponding metadata.
gen_duration : float
Target audio duration for the generation.
prompt_duration : float, optional
Duration of the audio prompt to use for continuation. If None, no audio prompt is used.
remove_text_conditioning : bool
Whether to drop the text (description) conditioning before generation.

Returns

gen_outputs (dict): Generation outputs, consisting of audio and audio tokens from both the generation and the prompt, along with additional information.
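
A hedged usage sketch, assuming solver is an eval-ready MusicGenSolver (e.g. from get_eval_solver_from_sig) with a generate dataloader:

batch = next(iter(solver.dataloaders['generate']))  # (audio [B, C, T], list of SegmentWithAttributes)
outputs = solver.run_generate_step(batch, gen_duration=10.0, prompt_duration=2.0)
gen_audio = outputs['gen_audio']   # generated waveform at the compression model sample rate
print(f"rtf: {outputs['rtf']:.2f}")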

def show(self) ‑> None
def show(self) -> None:
    """Show the compression model and LM model."""
    self.logger.info("Compression model:")
    self.log_model_summary(self.compression_model)
    self.logger.info("LM model:")
    self.log_model_summary(self.model)

Show the compression model and LM model.

Inherited members