Module audiocraft.train
Entry point for Dora to launch solvers that run training loops. For more information on how to use Dora, see: https://github.com/facebookresearch/dora
Functions
def get_solver(cfg)
Expand source code
def get_solver(cfg):
    """Build the solver described by `cfg` for the current distributed setup.

    The global batch size `cfg.dataset.batch_size`, and any per-split
    `batch_size` found under `cfg.dataset.<split>` for the 'train', 'valid',
    'evaluate' and 'generate' splits, are converted in place into per-GPU
    batch sizes by dividing them by the distributed world size. Dataset
    manifest paths are then resolved with `resolve_config_dset_paths` before
    the solver is instantiated.

    Args:
        cfg: Experiment configuration. It is modified in place.
    Returns:
        The solver returned by `solvers.get_solver(cfg)`.
    Raises:
        AssertionError: If a batch size is not divisible by the world size.
    """
    from . import solvers

    world_size = flashy.distrib.world_size()

    # Convert the global batch size into the batch size for each GPU.
    assert cfg.dataset.batch_size % world_size == 0, (
        f"dataset.batch_size={cfg.dataset.batch_size} must be divisible by "
        f"the world size ({world_size})")
    cfg.dataset.batch_size //= world_size

    # Per-split overrides of the batch size get the same per-GPU conversion.
    for split in ['train', 'valid', 'evaluate', 'generate']:
        if hasattr(cfg.dataset, split) and hasattr(cfg.dataset[split], 'batch_size'):
            assert cfg.dataset[split].batch_size % world_size == 0, (
                f"dataset.{split}.batch_size={cfg.dataset[split].batch_size} "
                f"must be divisible by the world size ({world_size})")
            cfg.dataset[split].batch_size //= world_size

    resolve_config_dset_paths(cfg)
    solver = solvers.get_solver(cfg)
    return solver
Expand source code
def get_solver_from_sig(sig: str, *args, **kwargs):
    """Build a Solver from a Dora experiment signature, e.g. to play with it in a notebook.

    The signature is resolved to its XP, and every remaining positional and
    keyword argument is forwarded untouched to `get_solver_from_xp`, whose
    documentation describes them.
    """
    return get_solver_from_xp(main.get_xp_from_sig(sig), *args, **kwargs)
# Return Solver object from Dora signature, i.e. to play with it from a notebook. See
`get_solver_from_xp()` for more information.
def get_solver_from_xp(xp: dora.xp.XP,
override_cfg: dict | omegaconf.dictconfig.DictConfig | None = None,
restore: bool = True,
load_best: bool = True,
ignore_state_keys: List[str] = [],
disable_fsdp: bool = True)
Expand source code
def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                       restore: bool = True, load_best: bool = True,
                       ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True):
    """Given a XP, return the Solver object.

    Args:
        xp (XP): Dora experiment for which to retrieve the solver.
        override_cfg (dict or None): If not None, should be a dict used to
            override some values in the config of `xp`. This will not impact
            the XP signature or folder. The format is different from the one
            used in Dora grids: nested keys should be nested dicts, not
            flattened, e.g. `{'optim': {'batch_size': 32}}`.
        restore (bool): If `True` (the default), restore state from the last checkpoint.
        load_best (bool): If `True` (the default), load the best state from the checkpoint.
        ignore_state_keys (list[str]): List of sources to ignore when loading the
            state, e.g. `optimizer`. The list passed in is never modified.
        disable_fsdp (bool): If True, disables FSDP entirely. This will also
            automatically skip loading the EMA. For solver-specific state sources,
            like the optimizer, you might want to combine it with
            `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`
            (checked when the config has FSDP enabled).
    Returns:
        The solver built by `get_solver` inside `xp.enter()`, restored from its
        checkpoint when `restore` is True.
    Raises:
        AssertionError: If `disable_fsdp` is True, the config uses FSDP and
            `load_best` is not True.
    """
    logger.info(f"Loading solver from XP {xp.sig}. "
                f"Overrides used: {xp.argv}")

    cfg = xp.cfg
    if override_cfg is not None:
        cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))

    if disable_fsdp and cfg.fsdp.use:
        # Validate before touching the config: without `override_cfg`, `cfg` is
        # `xp.cfg` itself, so failing after the assignment below would leave the
        # experiment's config modified.
        assert load_best is True, (
            "disable_fsdp=True requires load_best=True when the config uses FSDP")
        cfg.fsdp.use = False
        # Ignoring some keys that were FSDP sharded like model, ema, and best_state.
        # fsdp_best_state will be used in that case. When using a specific solver,
        # one is responsible for adding the relevant keys, e.g. 'optimizer'.
        # We could make something to automatically register those inside the solver,
        # but that seems overkill at this point.
        # `+` builds a new list, so the caller's list is left untouched.
        ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state']

    try:
        with xp.enter():
            solver = get_solver(cfg)
            if restore:
                solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys)
        return solver
    finally:
        # Always reset Hydra's global state, even if building/restoring failed.
        hydra.core.global_hydra.GlobalHydra.instance().clear()
# Given a XP, return the Solver object.
Args
xp (XP): Dora experiment for which to retrieve the solver.
override_cfg (dict or None): If not None, should be a dict used to override some values in the config of `xp`. This will not impact the XP signature or folder. The format is different from the one used in Dora grids: nested keys should be nested dicts, not flattened, e.g. `{'optim': {'batch_size': 32}}`.
restore (bool): If `True` (the default), restore state from the last checkpoint.
load_best (bool): If `True` (the default), load the best state from the checkpoint.
ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`.
disable_fsdp (bool): If `True`, disables FSDP entirely. This will also automatically skip loading the EMA. For solver-specific state sources, like the optimizer, you might want to combine it with `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`.
def init_seed_and_system(cfg)
Expand source code
def init_seed_and_system(cfg):
    """Seed the RNGs and configure process-wide runtime settings from `cfg`.

    In order, this:
    - sets the multiprocessing start method to `cfg.mp_start_method`;
    - seeds Python's `random`, NumPy and torch with `cfg.seed`;
    - limits torch threads and the MKL/OMP thread env vars to `cfg.num_threads`;
    - selects `cfg.efficient_attention_backend` for attention;
    - when running under SLURM, points TMPDIR to
      `/scratch/slurm_tmpdir/<SLURM_JOB_ID>` if that directory exists.
    """
    import numpy as np
    import torch
    import random
    from audiocraft.modules.transformer import set_efficient_attention_backend

    start_method = cfg.mp_start_method
    multiprocessing.set_start_method(start_method)
    logger.debug('Setting mp start method to %s', start_method)

    seed = cfg.seed
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed also seeds CUDA when it is available.
    torch.manual_seed(seed)

    n_threads = cfg.num_threads
    torch.set_num_threads(n_threads)
    for thread_env_var in ('MKL_NUM_THREADS', 'OMP_NUM_THREADS'):
        os.environ[thread_env_var] = str(n_threads)
    logger.debug('Setting num threads to %d', n_threads)

    attention_backend = cfg.efficient_attention_backend
    set_efficient_attention_backend(attention_backend)
    logger.debug('Setting efficient attention backend to %s', attention_backend)

    # Outside SLURM there is nothing left to configure.
    job_id = os.environ.get('SLURM_JOB_ID')
    if job_id is None:
        return
    scratch_dir = Path(f'/scratch/slurm_tmpdir/{job_id}')
    if scratch_dir.exists():
        logger.info("Changing tmpdir to %s", scratch_dir)
        os.environ['TMPDIR'] = str(scratch_dir)
Expand source code
def resolve_config_dset_paths(cfg):
    """Enable Dora to load manifest from git clone repository.

    Every string entry of `cfg.datasource` (the manifest location of each
    split) is rewritten in place through `git_save.to_absolute_path`;
    non-string entries are left as they are.
    """
    for split_name, manifest in cfg.datasource.items():
        if not isinstance(manifest, str):
            continue
        cfg.datasource[split_name] = git_save.to_absolute_path(manifest)
# Enable Dora to load manifest from git clone repository.