Module audiocraft.train
Entry point for Dora to launch solvers for running training loops. See https://github.com/facebookresearch/dora for more information on how to use Dora.
Functions
def get_solver(cfg)
def get_solver(cfg):
    from . import solvers
    # Convert batch size to batch size for each GPU
    assert cfg.dataset.batch_size % flashy.distrib.world_size() == 0
    cfg.dataset.batch_size //= flashy.distrib.world_size()
    for split in ['train', 'valid', 'evaluate', 'generate']:
        if hasattr(cfg.dataset, split) and hasattr(cfg.dataset[split], 'batch_size'):
            assert cfg.dataset[split].batch_size % flashy.distrib.world_size() == 0
            cfg.dataset[split].batch_size //= flashy.distrib.world_size()
    resolve_config_dset_paths(cfg)
    solver = solvers.get_solver(cfg)
    return solver
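In concrete terms, the per-GPU conversion above is an integer division of the configured batch size by the number of distributed workers; the world size and batch size below are hypothetical example values, not AudioCraft defaults.

# Hypothetical values: a global batch size of 64 split across 8 GPUs.
world_size = 8               # what flashy.distrib.world_size() would report on such a run
global_batch_size = 64       # cfg.dataset.batch_size as written in the config
assert global_batch_size % world_size == 0, "global batch size must divide evenly across GPUs"
per_gpu_batch_size = global_batch_size // world_size   # -> 8 samples per GPU

The same division is applied to any per-split batch_size (train, valid, evaluate, generate) present in the dataset config.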
def get_solver_from_sig(sig: str, *args, **kwargs)
def get_solver_from_sig(sig: str, *args, **kwargs):
    """Return Solver object from Dora signature, i.e. to play with it from a notebook.
    See `get_solver_from_xp` for more information.
    """
    xp = main.get_xp_from_sig(sig)
    return get_solver_from_xp(xp, *args, **kwargs)
Return Solver object from Dora signature, i.e. to play with it from a notebook. See get_solver_from_xp() for more information.
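A minimal sketch of loading a solver from a notebook, assuming a finished experiment; the signature string below is a placeholder for one of your own Dora runs.

from audiocraft.train import get_solver_from_sig

# '0123abcd' is a placeholder Dora signature, not a real experiment.
solver = get_solver_from_sig('0123abcd', restore=True, load_best=True)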
def get_solver_from_xp(xp: dora.xp.XP,
override_cfg: dict | omegaconf.dictconfig.DictConfig | None = None,
restore: bool = True,
load_best: bool = True,
ignore_state_keys: List[str] = [],
disable_fsdp: bool = True)
def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                       restore: bool = True, load_best: bool = True,
                       ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True):
    """Given a XP, return the Solver object.

    Args:
        xp (XP): Dora experiment for which to retrieve the solver.
        override_cfg (dict or None): If not None, should be a dict used to
            override some values in the config of `xp`. This will not impact
            the XP signature or folder. The format is different than the one used in Dora grids,
            nested keys should actually be nested dicts, not flattened, e.g.
            `{'optim': {'batch_size': 32}}`.
        restore (bool): If `True` (the default), restore state from the last checkpoint.
        load_best (bool): If `True` (the default), load the best state from the checkpoint.
        ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`.
        disable_fsdp (bool): if True, disables FSDP entirely. This will
            also automatically skip loading the EMA. For solver specific
            state sources, like the optimizer, you might want to
            use along `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`.
    """
    logger.info(f"Loading solver from XP {xp.sig}. "
                f"Overrides used: {xp.argv}")
    cfg = xp.cfg
    if override_cfg is not None:
        cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))
    if disable_fsdp and cfg.fsdp.use:
        cfg.fsdp.use = False
        assert load_best is True
        # ignoring some keys that were FSDP sharded like model, ema, and best_state.
        # fsdp_best_state will be used in that case. When using a specific solver,
        # one is responsible for adding the relevant keys, e.g. 'optimizer'.
        # We could make something to automatically register those inside the solver, but that
        # seem overkill at this point.
        ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state']

    try:
        with xp.enter():
            solver = get_solver(cfg)
            if restore:
                solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys)
        return solver
    finally:
        hydra.core.global_hydra.GlobalHydra.instance().clear()
Given an XP, return the Solver object.

Args
    xp : XP
        Dora experiment for which to retrieve the solver.
    override_cfg : dict or None
        If not None, should be a dict used to override some values in the config of xp.
        This will not impact the XP signature or folder. The format differs from the one
        used in Dora grids: nested keys should actually be nested dicts, not flattened,
        e.g. {'optim': {'batch_size': 32}}.
    restore : bool
        If True (the default), restore state from the last checkpoint.
    load_best : bool
        If True (the default), load the best state from the checkpoint.
    ignore_state_keys : list[str]
        List of sources to ignore when loading the state, e.g. optimizer.
    disable_fsdp : bool
        If True, disables FSDP entirely. This will also automatically skip loading the EMA.
        For solver-specific state sources, like the optimizer, you might want to pass
        ignore_state_keys=['optimizer'] along with it. Must be used with load_best=True.
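To illustrate the nested override_cfg format (as opposed to the flattened dotted keys used in Dora grids), here is a hedged sketch; the signature, the overridden key and its value are placeholders. Extra keyword arguments passed to get_solver_from_sig are forwarded to get_solver_from_xp.

from audiocraft.train import get_solver_from_sig

# '0123abcd' and the override below are placeholders, not values from this repository.
solver = get_solver_from_sig(
    '0123abcd',
    override_cfg={'dataset': {'batch_size': 16}},  # nested dicts, not 'dataset.batch_size=16'
    restore=True,
    load_best=True,
    ignore_state_keys=['optimizer'],               # skip solver-specific state such as the optimizer
)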
def init_seed_and_system(cfg)
def init_seed_and_system(cfg):
    import numpy as np
    import torch
    import random
    from audiocraft.modules.transformer import set_efficient_attention_backend
    multiprocessing.set_start_method(cfg.mp_start_method)
    logger.debug('Setting mp start method to %s', cfg.mp_start_method)
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    # torch also initialize cuda seed if available
    torch.manual_seed(cfg.seed)
    torch.set_num_threads(cfg.num_threads)
    os.environ['MKL_NUM_THREADS'] = str(cfg.num_threads)
    os.environ['OMP_NUM_THREADS'] = str(cfg.num_threads)
    logger.debug('Setting num threads to %d', cfg.num_threads)
    set_efficient_attention_backend(cfg.efficient_attention_backend)
    logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
    if 'SLURM_JOB_ID' in os.environ:
        tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID'])
        if tmpdir.exists():
            logger.info("Changing tmpdir to %s", tmpdir)
            os.environ['TMPDIR'] = str(tmpdir)
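The configuration fields read by init_seed_and_system can be summarized with a minimal OmegaConf object; the values below are hypothetical, the actual defaults come from the repository's Hydra configuration.

from omegaconf import OmegaConf
from audiocraft.train import init_seed_and_system

# Hypothetical values for the fields the function reads.
cfg = OmegaConf.create({
    'mp_start_method': 'forkserver',          # multiprocessing start method
    'seed': 2036,                             # seeds random, numpy and torch
    'num_threads': 1,                         # torch / MKL / OMP thread count
    'efficient_attention_backend': 'torch',   # passed to set_efficient_attention_backend
})
# Note: multiprocessing.set_start_method can only be called once per process.
init_seed_and_system(cfg)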
def resolve_config_dset_paths(cfg)
def resolve_config_dset_paths(cfg):
    """Enable Dora to load manifest from git clone repository."""
    # manifest files for the different splits
    for key, value in cfg.datasource.items():
        if isinstance(value, str):
            cfg.datasource[key] = git_save.to_absolute_path(value)
Enable Dora to load manifest from git clone repository.
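As a rough sketch of the effect, using a hypothetical datasource section (the keys and manifest paths below are placeholders): string entries are rewritten to absolute paths via dora's git_save.to_absolute_path, so that, per the docstring above, manifests resolve correctly when Dora runs the code from a git clone.

from omegaconf import OmegaConf
from audiocraft.train import resolve_config_dset_paths

# Hypothetical datasource section; real ones come from the dset configs.
cfg = OmegaConf.create({
    'datasource': {
        'max_sample_rate': 44100,              # non-string values are left untouched
        'train': 'egs/example/data.jsonl.gz',  # placeholder relative manifest path
        'valid': 'egs/example/data.jsonl.gz',
    }
})
resolve_config_dset_paths(cfg)
# cfg.datasource.train and cfg.datasource.valid are now absolute paths.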