Module audiocraft.solvers.compression
Functions
def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.dictconfig.DictConfig) ‑> dict
def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict:
    """Audio reconstruction evaluation method that can be conveniently pickled."""
    metrics = {}
    if cfg.evaluate.metrics.visqol:
        visqol = builders.get_visqol(cfg.metrics.visqol)
        metrics['visqol'] = visqol(y_pred, y, cfg.sample_rate)
    sisnr = builders.get_loss('sisnr', cfg)
    metrics['sisnr'] = sisnr(y_pred, y)
    return metrics
Audio reconstruction evaluation method that can be conveniently pickled.
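Because the function lives at module level rather than as a bound method, it can be pickled and shipped to worker processes, which is how CompressionSolver.evaluate() uses it. A minimal sketch, assuming a config that disables ViSQOL so only SI-SNR applies (the exact config keys required by builders.get_loss are an assumption here):

import torch
import omegaconf

# Hypothetical minimal config: 'evaluate.metrics.visqol' gates the ViSQOL metric;
# with it off, only the SI-SNR loss is built and reported.
cfg = omegaconf.OmegaConf.create({
    'sample_rate': 16000,
    'evaluate': {'metrics': {'visqol': False}},
})
y = torch.randn(4, 1, 16000)             # reference batch [B, C, T]
y_pred = y + 0.01 * torch.randn_like(y)  # stand-in reconstruction
# metrics = evaluate_audio_reconstruction(y_pred, y, cfg)
# -> {'sisnr': tensor(...)}; note builders.get_loss('sisnr', cfg) may require
#    additional loss-specific keys in cfg, which this sketch omits.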
Classes
class CompressionSolver (cfg: omegaconf.dictconfig.DictConfig)
class CompressionSolver(base.StandardSolver):
    """Solver for compression task.

    The compression task combines a set of perceptual and objective losses
    to train an EncodecModel (composed of an encoder-decoder and a quantizer)
    to perform high fidelity audio reconstruction.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.rng: torch.Generator  # set at each epoch
        self.adv_losses = builders.get_adversarial_losses(self.cfg)
        self.aux_losses = nn.ModuleDict()
        self.info_losses = nn.ModuleDict()
        assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
        loss_weights = dict()
        for loss_name, weight in self.cfg.losses.items():
            if loss_name in ['adv', 'feat']:
                for adv_name, _ in self.adv_losses.items():
                    loss_weights[f'{loss_name}_{adv_name}'] = weight
            elif weight > 0:
                self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
                loss_weights[loss_name] = weight
            else:
                self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
        self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)
        self.register_stateful('adv_losses')

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        # best model is the last for the compression model
        return None

    def build_model(self):
        """Instantiate model and optimizer."""
        # Model and optimizer
        self.model = models.builders.get_compression_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Instantiate audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        """Show the compression model and employed adversarial loss."""
        self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:")
        self.log_model_summary(self.model)
        self.logger.info("Adversarial loss:")
        self.log_model_summary(self.adv_losses)
        self.logger.info("Auxiliary losses:")
        self.logger.info(self.aux_losses)
        self.logger.info("Info losses:")
        self.logger.info(self.info_losses)

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch."""
        x = batch.to(self.device)
        y = x.clone()

        qres = self.model(x)
        assert isinstance(qres, quantization.QuantizedResult)
        y_pred = qres.x
        # Log bandwidth in kb/s
        metrics['bandwidth'] = qres.bandwidth.mean()

        if self.is_training:
            d_losses: dict = {}
            if len(self.adv_losses) > 0 and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every:
                for adv_name, adversary in self.adv_losses.items():
                    disc_loss = adversary.train_adv(y_pred, y)
                    d_losses[f'd_{adv_name}'] = disc_loss
                metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values())))
            metrics.update(d_losses)

        balanced_losses: dict = {}
        other_losses: dict = {}

        # penalty from quantization
        if qres.penalty is not None and qres.penalty.requires_grad:
            other_losses['penalty'] = qres.penalty  # penalty term from the quantizer

        # adversarial losses
        for adv_name, adversary in self.adv_losses.items():
            adv_loss, feat_loss = adversary(y_pred, y)
            balanced_losses[f'adv_{adv_name}'] = adv_loss
            balanced_losses[f'feat_{adv_name}'] = feat_loss

        # auxiliary losses
        for loss_name, criterion in self.aux_losses.items():
            loss = criterion(y_pred, y)
            balanced_losses[loss_name] = loss

        # weighted losses
        metrics.update(balanced_losses)
        metrics.update(other_losses)
        metrics.update(qres.metrics)

        if self.is_training:
            # backprop losses that are not handled by balancer
            other_loss = torch.tensor(0., device=self.device)
            if 'penalty' in other_losses:
                other_loss += other_losses['penalty']
            if other_loss.requires_grad:
                other_loss.backward(retain_graph=True)
                ratio1 = sum(p.grad.data.norm(p=2).pow(2) for p in self.model.parameters() if p.grad is not None)
                assert isinstance(ratio1, torch.Tensor)
                metrics['ratio1'] = ratio1.sqrt()

            # balancer losses backward, returns effective training loss
            # with effective weights at the current batch.
            metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred)
            # add metrics corresponding to weight ratios
            metrics.update(self.balancer.metrics)
            ratio2 = sum(p.grad.data.norm(p=2).pow(2) for p in self.model.parameters() if p.grad is not None)
            assert isinstance(ratio2, torch.Tensor)
            metrics['ratio2'] = ratio2.sqrt()

            # optim
            flashy.distrib.sync_model(self.model)
            if self.cfg.optim.max_norm:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.optim.max_norm
                )
            self.optimizer.step()
            self.optimizer.zero_grad()

        # informative losses only
        info_losses: dict = {}
        with torch.no_grad():
            for loss_name, criterion in self.info_losses.items():
                loss = criterion(y_pred, y)
                info_losses[loss_name] = loss

        metrics.update(info_losses)

        # aggregated GAN losses: this is useful to report adv and feat across
        # different adversarial loss setups
        adv_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('adv')]
        if len(adv_losses) > 0:
            metrics['adv'] = torch.sum(torch.stack(adv_losses))
        feat_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('feat')]
        if len(feat_losses) > 0:
            metrics['feat'] = torch.sum(torch.stack(feat_losses))

        return metrics

    def run_epoch(self):
        # reset random seed at the beginning of the epoch
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        # run epoch
        super().run_epoch()

    def evaluate(self):
        """Evaluate stage. Runs audio reconstruction evaluation."""
        self.model.eval()
        evaluate_stage_name = str(self.current_stage)

        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
        average = flashy.averager()

        pendings = []
        ctx = multiprocessing.get_context('spawn')
        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
            for idx, batch in enumerate(lp):
                x = batch.to(self.device)
                with torch.no_grad():
                    qres = self.model(x)

                y_pred = qres.x.cpu()
                y = batch.cpu()  # should already be on CPU but just in case
                pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))

            metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
            for pending in metrics_lp:
                metrics = pending.result()
                metrics = average(metrics)

        metrics = flashy.distrib.average_metrics(metrics, len(loader))
        return metrics

    def generate(self):
        """Generate stage."""
        self.model.eval()
        sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
        generate_stage_name = str(self.current_stage)

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
            reference, _ = batch
            reference = reference.to(self.device)
            with torch.no_grad():
                qres = self.model(reference)
                assert isinstance(qres, quantization.QuantizedResult)

            reference = reference.cpu()
            estimate = qres.x.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)

        flashy.distrib.barrier()

    def load_from_pretrained(self, name: str) -> dict:
        model = models.CompressionModel.get_pretrained(name)
        if isinstance(model, models.DAC):
            raise RuntimeError("Cannot fine tune a DAC model.")
        elif isinstance(model, models.HFEncodecCompressionModel):
            self.logger.warning('Trying to automatically convert a HuggingFace model '
                                'to AudioCraft, this might fail!')
            state = model.model.state_dict()
            new_state = {}
            for k, v in state.items():
                if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
                    # We need to determine if this is a convtr or a regular conv.
                    layer = int(k.split('.')[2])
                    if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
                        k = k.replace('.conv.', '.convtr.')
                k = k.replace('encoder.layers.', 'encoder.model.')
                k = k.replace('decoder.layers.', 'decoder.model.')
                k = k.replace('conv.', 'conv.conv.')
                k = k.replace('convtr.', 'convtr.convtr.')
                k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
                k = k.replace('.codebook.', '._codebook.')
                new_state[k] = v
            state = new_state
        elif isinstance(model, models.EncodecModel):
            state = model.state_dict()
        else:
            raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
        return {
            'best_state': {'model': state}
        }

    @staticmethod
    def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
                              device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a CompressionModel from a given checkpoint path or dora sig.
        This method is a convenient endpoint to load a CompressionModel to use in other solvers.

        Args:
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
                This also supports pre-trained models by using a path of the form //pretrained/NAME.
                See `model_from_pretrained` for a list of supported pretrained models.
            device (torch.device or str): Device on which the model is loaded.
        """
        checkpoint_path = str(checkpoint_path)
        if checkpoint_path.startswith('//pretrained/'):
            name = checkpoint_path.split('/', 3)[-1]
            return models.CompressionModel.get_pretrained(name, device)
        logger = logging.getLogger(__name__)
        logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
        _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
        assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
        state = checkpoint.load_checkpoint(_checkpoint_path)
        assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
        cfg = state['xp.cfg']
        cfg.device = device
        compression_model = models.builders.get_compression_model(cfg).to(device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"

        assert 'best_state' in state and state['best_state'] != {}
        assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
        compression_model.load_state_dict(state['best_state']['model'])
        compression_model.eval()
        logger.info("Compression model loaded!")
        return compression_model

    @staticmethod
    def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
                                      checkpoint_path: tp.Union[Path, str],
                                      device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.

        Args:
            cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
            device (torch.device or str): Device on which the model is loaded.
        """
        compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
        compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
        return compression_model
Solver for compression task.
The compression task combines a set of perceptual and objective losses to train an EncodecModel (composed of an encoder-decoder and a quantizer) to perform high-fidelity audio reconstruction.
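The routing of cfg.losses in __init__ can be read directly from the source above: 'adv'/'feat' weights are duplicated per adversary, strictly positive weights become balanced auxiliary losses, and zero weights become metric-only info losses. An illustrative sketch with hypothetical weights and adversary name:

# Illustrative only: the weights and the 'msstftd' adversary name are assumptions.
losses = {'adv': 4.0, 'feat': 4.0, 'l1': 0.1, 'msspec': 2.0, 'sisnr': 0.0}
adv_names = ['msstftd']  # e.g. a multi-scale STFT discriminator

loss_weights = {}
aux, info = [], []
for name, weight in losses.items():
    if name in ('adv', 'feat'):
        for adv in adv_names:
            loss_weights[f'{name}_{adv}'] = weight   # adv_msstftd, feat_msstftd
    elif weight > 0:
        loss_weights[name] = weight                  # balanced: l1, msspec
        aux.append(name)
    else:
        info.append(name)                            # tracked only: sisnr

# loss_weights -> {'adv_msstftd': 4.0, 'feat_msstftd': 4.0, 'l1': 0.1, 'msspec': 2.0}
# aux -> ['l1', 'msspec']; info -> ['sisnr'] (never backpropagated)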
Ancestors
- StandardSolver
- abc.ABC
- flashy.solver.BaseSolver
Static methods
def model_from_checkpoint(checkpoint_path: str | pathlib.Path, device: torch.device | str = 'cpu') ‑> CompressionModel
@staticmethod
def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
                          device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
    """Instantiate a CompressionModel from a given checkpoint path or dora sig.
    This method is a convenient endpoint to load a CompressionModel to use in other solvers.

    Args:
        checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
            This also supports pre-trained models by using a path of the form //pretrained/NAME.
            See `model_from_pretrained` for a list of supported pretrained models.
        device (torch.device or str): Device on which the model is loaded.
    """
    checkpoint_path = str(checkpoint_path)
    if checkpoint_path.startswith('//pretrained/'):
        name = checkpoint_path.split('/', 3)[-1]
        return models.CompressionModel.get_pretrained(name, device)
    logger = logging.getLogger(__name__)
    logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
    _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
    assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
    state = checkpoint.load_checkpoint(_checkpoint_path)
    assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
    cfg = state['xp.cfg']
    cfg.device = device
    compression_model = models.builders.get_compression_model(cfg).to(device)
    assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"

    assert 'best_state' in state and state['best_state'] != {}
    assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
    compression_model.load_state_dict(state['best_state']['model'])
    compression_model.eval()
    logger.info("Compression model loaded!")
    return compression_model
Instantiate a CompressionModel from a given checkpoint path or dora sig. This method is a convenient endpoint to load a CompressionModel to use in other solvers.
Args
checkpoint_path : Path or str
    Path to checkpoint or dora sig from where the checkpoint is resolved. This also supports
    pre-trained models by using a path of the form //pretrained/NAME. See model_from_pretrained
    for a list of supported pretrained models. A usage sketch follows this argument list.
device : torch.device or str
    Device on which the model is loaded.
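A usage sketch; the pretrained name below is illustrative of the //pretrained/NAME form rather than a guaranteed entry:

from audiocraft.solvers.compression import CompressionSolver

# Pre-trained model via the //pretrained/ prefix (name shown is an assumption).
model = CompressionSolver.model_from_checkpoint('//pretrained/facebook/encodec_32khz', device='cpu')

# Or resolve a training checkpoint by path or dora sig; exported checkpoints
# must instead go through the //pretrained/ prefix (the method asserts this).
# model = CompressionSolver.model_from_checkpoint('/checkpoints/my_run/checkpoint.th', device='cuda')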
def wrapped_model_from_checkpoint(cfg: omegaconf.dictconfig.DictConfig, checkpoint_path: str | pathlib.Path, device: torch.device | str = 'cpu') ‑> CompressionModel
@staticmethod
def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
                                  checkpoint_path: tp.Union[Path, str],
                                  device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
    """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.

    Args:
        cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
        checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
        device (torch.device or str): Device on which the model is loaded.
    """
    compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
    compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
    return compression_model
Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.
Args
cfg : omegaconf.DictConfig
    Configuration to read from for wrapped mode.
checkpoint_path : Path or str
    Path to checkpoint or dora sig from where the checkpoint is resolved. A usage sketch
    follows this argument list.
device : torch.device or str
    Device on which the model is loaded.
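A sketch of the wrapped variant; which keys get_wrapped_compression_model reads from cfg is solver-specific, so the config path and contents here are assumptions:

import omegaconf
from audiocraft.solvers.compression import CompressionSolver

cfg = omegaconf.OmegaConf.load('config/my_solver.yaml')  # hypothetical solver config
# Same checkpoint resolution as model_from_checkpoint, then wrapped per cfg:
# wrapped = CompressionSolver.wrapped_model_from_checkpoint(
#     cfg, '//pretrained/facebook/encodec_32khz', device='cpu')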
Methods
def build_dataloaders(self)
def build_dataloaders(self):
    """Instantiate audio dataloaders for each stage."""
    self.dataloaders = builders.get_audio_datasets(self.cfg)
Instantiate audio dataloaders for each stage.
def build_model(self)
def build_model(self):
    """Instantiate model and optimizer."""
    # Model and optimizer
    self.model = models.builders.get_compression_model(self.cfg).to(self.device)
    self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
    self.register_stateful('model', 'optimizer')
    self.register_best_state('model')
    self.register_ema('model')
Instantiate model and optimizer.
def evaluate(self)
def evaluate(self):
    """Evaluate stage. Runs audio reconstruction evaluation."""
    self.model.eval()
    evaluate_stage_name = str(self.current_stage)

    loader = self.dataloaders['evaluate']
    updates = len(loader)
    lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
    average = flashy.averager()

    pendings = []
    ctx = multiprocessing.get_context('spawn')
    with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
        for idx, batch in enumerate(lp):
            x = batch.to(self.device)
            with torch.no_grad():
                qres = self.model(x)

            y_pred = qres.x.cpu()
            y = batch.cpu()  # should already be on CPU but just in case
            pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))

        metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
        for pending in metrics_lp:
            metrics = pending.result()
            metrics = average(metrics)

    metrics = flashy.distrib.average_metrics(metrics, len(loader))
    return metrics
Evaluate stage. Runs audio reconstruction evaluation.
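The method keeps GPU inference on the main process and offloads metric computation to a 'spawn' process pool, which is why evaluate_audio_reconstruction must be picklable. A self-contained sketch of the same pattern, using the standard library in place of audiocraft's get_pool_executor helper:

import multiprocessing
from concurrent.futures import ProcessPoolExecutor

import torch

def compute_metrics(y_pred, y):
    # Stand-in for evaluate_audio_reconstruction: must be module-level to pickle.
    return {'mse': float(((y_pred - y) ** 2).mean())}

if __name__ == '__main__':
    ctx = multiprocessing.get_context('spawn')  # avoids forking CUDA state
    with ProcessPoolExecutor(max_workers=2, mp_context=ctx) as pool:
        pendings = []
        for _ in range(4):
            y = torch.randn(1, 16000)
            pendings.append(pool.submit(compute_metrics, y + 0.01 * torch.randn_like(y), y))
        results = [p.result() for p in pendings]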
def load_from_pretrained(self, name: str) ‑> dict
def load_from_pretrained(self, name: str) -> dict:
    model = models.CompressionModel.get_pretrained(name)
    if isinstance(model, models.DAC):
        raise RuntimeError("Cannot fine tune a DAC model.")
    elif isinstance(model, models.HFEncodecCompressionModel):
        self.logger.warning('Trying to automatically convert a HuggingFace model '
                            'to AudioCraft, this might fail!')
        state = model.model.state_dict()
        new_state = {}
        for k, v in state.items():
            if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
                # We need to determine if this is a convtr or a regular conv.
                layer = int(k.split('.')[2])
                if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
                    k = k.replace('.conv.', '.convtr.')
            k = k.replace('encoder.layers.', 'encoder.model.')
            k = k.replace('decoder.layers.', 'decoder.model.')
            k = k.replace('conv.', 'conv.conv.')
            k = k.replace('convtr.', 'convtr.convtr.')
            k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
            k = k.replace('.codebook.', '._codebook.')
            new_state[k] = v
        state = new_state
    elif isinstance(model, models.EncodecModel):
        state = model.state_dict()
    else:
        raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
    return {
        'best_state': {'model': state}
    }
def show(self)
def show(self):
    """Show the compression model and employed adversarial loss."""
    self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:")
    self.log_model_summary(self.model)
    self.logger.info("Adversarial loss:")
    self.log_model_summary(self.adv_losses)
    self.logger.info("Auxiliary losses:")
    self.logger.info(self.aux_losses)
    self.logger.info("Info losses:")
    self.logger.info(self.info_losses)
Show the compression model and employed adversarial loss.
Inherited members
- StandardSolver: autocast, best_metric_name, commit, common_train_valid, generate, get_eval_solver_from_sig, initialize_ema, load_checkpoints, log_model_summary, register_best_state, register_ema, restore, run, run_epoch, run_one_stage, run_step, save_checkpoints, should_run_stage, should_stop_training, train, update_best_state_from_stage, valid