Module audiocraft.train

Entry point for dora to launch solvers for running training loops. See more info on how to use dora: https://github.com/facebookresearch/dora

Expand source code
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Entry point for dora to launch solvers for running training loops.
See more info on how to use dora: https://github.com/facebookresearch/dora
"""

import logging
import multiprocessing
import os
from pathlib import Path
import sys
import typing as tp

from dora import git_save, hydra_main, XP
import flashy
import hydra
import omegaconf

from .environment import AudioCraftEnvironment
from .utils.cluster import get_slurm_parameters

logger = logging.getLogger(__name__)


def resolve_config_dset_paths(cfg):
    """Make the string entries of `cfg.datasource` absolute, in place.

    Every value of `cfg.datasource` that is a `str` (the manifest files for
    the different splits) is replaced by `git_save.to_absolute_path(value)`,
    which is what lets Dora load manifests from a git clone of the repository.
    Entries of any other type are left as they are.

    Args:
        cfg: Configuration exposing a `datasource` mapping; updated in place.
    """
    for name, entry in cfg.datasource.items():
        if not isinstance(entry, str):
            # Only path-like string entries are rewritten.
            continue
        cfg.datasource[name] = git_save.to_absolute_path(entry)


def get_solver(cfg):
    """Instantiate the solver for `cfg` after converting batch sizes to per-GPU values.

    `cfg.dataset.batch_size`, and the `batch_size` of each of the `train`,
    `valid`, `evaluate` and `generate` sections of `cfg.dataset` that defines
    one, is divided in place by `flashy.distrib.world_size()`. Each of them
    must be an exact multiple of the world size. The dataset paths are then
    resolved with `resolve_config_dset_paths` and the config is passed to
    `solvers.get_solver`.

    Args:
        cfg: Experiment configuration; modified in place.
    Returns:
        The solver returned by `solvers.get_solver(cfg)`.
    """
    from . import solvers

    world_size = flashy.distrib.world_size()
    dataset_cfg = cfg.dataset

    # Global batch size -> batch size for each GPU.
    assert dataset_cfg.batch_size % world_size == 0
    dataset_cfg.batch_size //= world_size

    # Same conversion for the splits that carry their own batch size.
    for split in ('train', 'valid', 'evaluate', 'generate'):
        if not hasattr(dataset_cfg, split):
            continue
        split_cfg = dataset_cfg[split]
        if not hasattr(split_cfg, 'batch_size'):
            continue
        assert split_cfg.batch_size % world_size == 0
        split_cfg.batch_size //= world_size

    resolve_config_dset_paths(cfg)
    return solvers.get_solver(cfg)


def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                       restore: bool = True, load_best: bool = True,
                       ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True):
    """Given an XP, return the Solver object.

    Args:
        xp (XP): Dora experiment for which to retrieve the solver.
        override_cfg (dict or None): If not None, should be a dict used to
            override some values in the config of `xp`. This will not impact
            the XP signature or folder. The format is different
            from the one used in Dora grids: nested keys should actually be nested dicts,
            not flattened, e.g. `{'optim': {'batch_size': 32}}`.
        restore (bool): If `True` (the default), restore state from the last checkpoint.
        load_best (bool): If `True` (the default), load the best state from the checkpoint.
        ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`.
            The list is copied, the caller's object is never modified nor handed to the solver.
        disable_fsdp (bool): if True, disables FSDP entirely. This will
            also automatically skip loading the EMA. For solver specific
            state sources, like the optimizer, you might want to
            also pass `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`.
    Returns:
        The solver built by `get_solver`, with its state restored when `restore` is True.
    Raises:
        AssertionError: if FSDP must be disabled (`disable_fsdp` is True and the config
            has `fsdp.use` set) while `load_best` is not True.
    """
    logger.info("Loading solver from XP %s. Overrides used: %s", xp.sig, xp.argv)
    # Work on a private copy: the default value is one list object shared by every call,
    # and the list is passed on to `solver.restore` below.
    ignore_state_keys = list(ignore_state_keys)
    # NOTE(review): without `override_cfg`, `cfg` is the very object returned by `xp.cfg`,
    # so the `fsdp.use` change below and the batch size conversion done by `get_solver`
    # are applied to it directly -- confirm this is intended.
    cfg = xp.cfg
    if override_cfg is not None:
        cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))
    if disable_fsdp and cfg.fsdp.use:
        # Validate before touching the config, so that a rejected call leaves it unchanged.
        assert load_best is True, "disable_fsdp=True requires load_best=True when the XP uses FSDP."
        cfg.fsdp.use = False
        # ignoring some keys that were FSDP sharded like model, ema, and best_state.
        # fsdp_best_state will be used in that case. When using a specific solver,
        # one is responsible for adding the relevant keys, e.g. 'optimizer'.
        # We could make something to automatically register those inside the solver, but that
        # seem overkill at this point.
        ignore_state_keys += ['model', 'ema', 'best_state']

    try:
        with xp.enter():
            solver = get_solver(cfg)
            if restore:
                solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys)
        return solver
    finally:
        # Runs on success and on failure alike: clear the GlobalHydra singleton.
        hydra.core.global_hydra.GlobalHydra.instance().clear()


def get_solver_from_sig(sig: str, *args, **kwargs):
    """Look up the Dora XP with signature `sig` and return its Solver object.

    Convenient for playing with a solver from a notebook. The XP is obtained
    with `main.get_xp_from_sig(sig)`; every other positional and keyword
    argument is forwarded unchanged to `get_solver_from_xp`, see that function
    for more information.
    """
    return get_solver_from_xp(main.get_xp_from_sig(sig), *args, **kwargs)


def init_seed_and_system(cfg):
    """Seed the random generators and apply the system settings found in `cfg`.

    Reads `cfg.mp_start_method`, `cfg.seed`, `cfg.num_threads` and
    `cfg.efficient_attention_backend`. When `SLURM_JOB_ID` is set in the
    environment and the folder `/scratch/slurm_tmpdir/<SLURM_JOB_ID>` exists,
    `TMPDIR` is pointed at that folder.

    Args:
        cfg: Experiment configuration providing the fields listed above.
    """
    import numpy as np
    import torch
    import random
    from audiocraft.modules.transformer import set_efficient_attention_backend

    multiprocessing.set_start_method(cfg.mp_start_method)
    logger.debug('Setting mp start method to %s', cfg.mp_start_method)

    # One seed shared by Python, NumPy and torch
    # (torch also initialize cuda seed if available).
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(cfg.seed)

    # Thread count, both for torch and through the MKL / OpenMP environment variables.
    torch.set_num_threads(cfg.num_threads)
    for env_name in ('MKL_NUM_THREADS', 'OMP_NUM_THREADS'):
        os.environ[env_name] = str(cfg.num_threads)
    logger.debug('Setting num threads to %d', cfg.num_threads)

    set_efficient_attention_backend(cfg.efficient_attention_backend)
    logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)

    slurm_job_id = os.environ.get('SLURM_JOB_ID')
    if slurm_job_id is None:
        # Not inside a SLURM job: TMPDIR is left alone.
        return
    tmpdir = Path('/scratch/slurm_tmpdir/' + slurm_job_id)
    if tmpdir.exists():
        logger.info("Changing tmpdir to %s", tmpdir)
        os.environ['TMPDIR'] = str(tmpdir)


@hydra_main(config_path='../config', config_name='config', version_base='1.1')
def main(cfg):
    """Entry point launched by Dora: build the solver for `cfg` and run it.

    Three modes, selected from the config:
        - `cfg.show`: call `solver.show()` and stop.
        - `cfg.execute_only`: restore the checkpoint and run only that stage. Requires
          `cfg.execute_inplace` or a `cfg.continue_from` value.
        - otherwise: run the solver with `solver.run()` and return its result.
    """
    init_seed_and_system(cfg)

    # Logs go both to the XP specific folder and to stderr; the log file is named
    # after the stage when a single stage is executed.
    if cfg.execute_only:
        log_name = '%s.log.{rank}' % cfg.execute_only
    else:
        log_name = 'solver.log.{rank}'
    flashy.setup_logging(level=str(cfg.logging.level).upper(), log_name=log_name)

    # Distributed training setup, nothing to specify when using Dora.
    flashy.distrib.init()
    solver = get_solver(cfg)

    if cfg.show:
        solver.show()
        return

    if not cfg.execute_only:
        return solver.run()

    assert cfg.execute_inplace or cfg.continue_from is not None, (
        "Please explicitly specify the checkpoint to continue from with continue_from=<sig_or_path> "
        "when running with execute_only or set execute_inplace to True."
    )
    solver.restore(replay_metrics=False)  # load checkpoint
    solver.run_one_stage(cfg.execute_only)


# Module-level configuration of the Dora entry point, executed at import time.
# The Dora directory is the one reported by AudioCraftEnvironment.
main.dora.dir = AudioCraftEnvironment.get_dora_dir()
# The base SLURM config is passed through get_slurm_parameters and stored back.
main._base_cfg.slurm = get_slurm_parameters(main._base_cfg.slurm)

# A shared folder that is configured but not readable by this process is dropped,
# with a warning on stderr.
if main.dora.shared is not None and not os.access(main.dora.shared, os.R_OK):
    print("No read permission on dora.shared folder, ignoring it.", file=sys.stderr)
    main.dora.shared = None

# Run the entry point when the module is executed as a script.
if __name__ == '__main__':
    main()

Functions

def get_solver(cfg)
Expand source code
def get_solver(cfg):
    """Instantiate the solver for `cfg` after converting batch sizes to per-GPU values.

    `cfg.dataset.batch_size`, and the `batch_size` of each of the `train`,
    `valid`, `evaluate` and `generate` sections of `cfg.dataset` that defines
    one, is divided in place by `flashy.distrib.world_size()`. Each of them
    must be an exact multiple of the world size. The dataset paths are then
    resolved with `resolve_config_dset_paths` and the config is passed to
    `solvers.get_solver`.

    Args:
        cfg: Experiment configuration; modified in place.
    Returns:
        The solver returned by `solvers.get_solver(cfg)`.
    """
    from . import solvers

    world_size = flashy.distrib.world_size()
    dataset_cfg = cfg.dataset

    # Global batch size -> batch size for each GPU.
    assert dataset_cfg.batch_size % world_size == 0
    dataset_cfg.batch_size //= world_size

    # Same conversion for the splits that carry their own batch size.
    for split in ('train', 'valid', 'evaluate', 'generate'):
        if not hasattr(dataset_cfg, split):
            continue
        split_cfg = dataset_cfg[split]
        if not hasattr(split_cfg, 'batch_size'):
            continue
        assert split_cfg.batch_size % world_size == 0
        split_cfg.batch_size //= world_size

    resolve_config_dset_paths(cfg)
    return solvers.get_solver(cfg)
def get_solver_from_sig(sig: str, *args, **kwargs)

Return Solver object from Dora signature, i.e. to play with it from a notebook. See get_solver_from_xp() for more information.

Expand source code
def get_solver_from_sig(sig: str, *args, **kwargs):
    """Look up the Dora XP with signature `sig` and return its Solver object.

    Convenient for playing with a solver from a notebook. The XP is obtained
    with `main.get_xp_from_sig(sig)`; every other positional and keyword
    argument is forwarded unchanged to `get_solver_from_xp`, see that function
    for more information.
    """
    return get_solver_from_xp(main.get_xp_from_sig(sig), *args, **kwargs)
def get_solver_from_xp(xp: dora.xp.XP, override_cfg: Union[dict, omegaconf.dictconfig.DictConfig, None] = None, restore: bool = True, load_best: bool = True, ignore_state_keys: List[str] = [], disable_fsdp: bool = True)

Given a XP, return the Solver object.

Args

xp : XP
Dora experiment for which to retrieve the solver.
override_cfg : dict or None
If not None, should be a dict used to override some values in the config of xp. This will not impact the XP signature or folder. The format is different from the one used in Dora grids: nested keys should actually be nested dicts, not flattened, e.g. {'optim': {'batch_size': 32}}.
restore : bool
If True (the default), restore state from the last checkpoint.
load_best : bool
If True (the default), load the best state from the checkpoint.
ignore_state_keys : list[str]
List of sources to ignore when loading the state, e.g. optimizer.
disable_fsdp : bool
If True, disables FSDP entirely. This will also automatically skip loading the EMA. For solver-specific state sources, like the optimizer, you might want to also pass ignore_state_keys=['optimizer']. Must be used with load_best=True.
Expand source code
def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                       restore: bool = True, load_best: bool = True,
                       ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True):
    """Given an XP, return the Solver object.

    Args:
        xp (XP): Dora experiment for which to retrieve the solver.
        override_cfg (dict or None): If not None, should be a dict used to
            override some values in the config of `xp`. This will not impact
            the XP signature or folder. The format is different
            from the one used in Dora grids: nested keys should actually be nested dicts,
            not flattened, e.g. `{'optim': {'batch_size': 32}}`.
        restore (bool): If `True` (the default), restore state from the last checkpoint.
        load_best (bool): If `True` (the default), load the best state from the checkpoint.
        ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`.
            The list is copied, the caller's object is never modified nor handed to the solver.
        disable_fsdp (bool): if True, disables FSDP entirely. This will
            also automatically skip loading the EMA. For solver specific
            state sources, like the optimizer, you might want to
            also pass `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`.
    Returns:
        The solver built by `get_solver`, with its state restored when `restore` is True.
    Raises:
        AssertionError: if FSDP must be disabled (`disable_fsdp` is True and the config
            has `fsdp.use` set) while `load_best` is not True.
    """
    logger.info("Loading solver from XP %s. Overrides used: %s", xp.sig, xp.argv)
    # Work on a private copy: the default value is one list object shared by every call,
    # and the list is passed on to `solver.restore` below.
    ignore_state_keys = list(ignore_state_keys)
    # NOTE(review): without `override_cfg`, `cfg` is the very object returned by `xp.cfg`,
    # so the `fsdp.use` change below and the batch size conversion done by `get_solver`
    # are applied to it directly -- confirm this is intended.
    cfg = xp.cfg
    if override_cfg is not None:
        cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))
    if disable_fsdp and cfg.fsdp.use:
        # Validate before touching the config, so that a rejected call leaves it unchanged.
        assert load_best is True, "disable_fsdp=True requires load_best=True when the XP uses FSDP."
        cfg.fsdp.use = False
        # ignoring some keys that were FSDP sharded like model, ema, and best_state.
        # fsdp_best_state will be used in that case. When using a specific solver,
        # one is responsible for adding the relevant keys, e.g. 'optimizer'.
        # We could make something to automatically register those inside the solver, but that
        # seem overkill at this point.
        ignore_state_keys += ['model', 'ema', 'best_state']

    try:
        with xp.enter():
            solver = get_solver(cfg)
            if restore:
                solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys)
        return solver
    finally:
        # Runs on success and on failure alike: clear the GlobalHydra singleton.
        hydra.core.global_hydra.GlobalHydra.instance().clear()
def init_seed_and_system(cfg)
Expand source code
def init_seed_and_system(cfg):
    """Seed the random generators and apply the system settings found in `cfg`.

    Reads `cfg.mp_start_method`, `cfg.seed`, `cfg.num_threads` and
    `cfg.efficient_attention_backend`. When `SLURM_JOB_ID` is set in the
    environment and the folder `/scratch/slurm_tmpdir/<SLURM_JOB_ID>` exists,
    `TMPDIR` is pointed at that folder.

    Args:
        cfg: Experiment configuration providing the fields listed above.
    """
    import numpy as np
    import torch
    import random
    from audiocraft.modules.transformer import set_efficient_attention_backend

    multiprocessing.set_start_method(cfg.mp_start_method)
    logger.debug('Setting mp start method to %s', cfg.mp_start_method)

    # One seed shared by Python, NumPy and torch
    # (torch also initialize cuda seed if available).
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(cfg.seed)

    # Thread count, both for torch and through the MKL / OpenMP environment variables.
    torch.set_num_threads(cfg.num_threads)
    for env_name in ('MKL_NUM_THREADS', 'OMP_NUM_THREADS'):
        os.environ[env_name] = str(cfg.num_threads)
    logger.debug('Setting num threads to %d', cfg.num_threads)

    set_efficient_attention_backend(cfg.efficient_attention_backend)
    logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)

    slurm_job_id = os.environ.get('SLURM_JOB_ID')
    if slurm_job_id is None:
        # Not inside a SLURM job: TMPDIR is left alone.
        return
    tmpdir = Path('/scratch/slurm_tmpdir/' + slurm_job_id)
    if tmpdir.exists():
        logger.info("Changing tmpdir to %s", tmpdir)
        os.environ['TMPDIR'] = str(tmpdir)
def resolve_config_dset_paths(cfg)

Enable Dora to load manifest from git clone repository.

Expand source code
def resolve_config_dset_paths(cfg):
    """Make the string entries of `cfg.datasource` absolute, in place.

    Every value of `cfg.datasource` that is a `str` (the manifest files for
    the different splits) is replaced by `git_save.to_absolute_path(value)`,
    which is what lets Dora load manifests from a git clone of the repository.
    Entries of any other type are left as they are.

    Args:
        cfg: Configuration exposing a `datasource` mapping; updated in place.
    """
    for name, entry in cfg.datasource.items():
        if not isinstance(entry, str):
            # Only path-like string entries are rewritten.
            continue
        cfg.datasource[name] = git_save.to_absolute_path(entry)