Module audiocraft.optim.fsdp

Wrapper around FSDP for more convenient use in the training loops.

Functions

def is_fsdp_used() -> bool
Expand source code
def is_fsdp_used() -> bool:
    """Return whether we are using FSDP.

    Reads the configuration of the currently running Dora experiment.

    Returns:
        ``False`` when no Dora XP is active or when its config has no
        ``fsdp`` section; otherwise the value of ``cfg.fsdp.use``.
    """
    # Going through the global Dora XP is a bit of a hack, but it lets this
    # helper be called from anywhere without threading the config around.
    if not dora.is_xp():
        return False
    xp_cfg = dora.get_xp().cfg
    if not hasattr(xp_cfg, 'fsdp'):
        return False
    return xp_cfg.fsdp.use

Return whether we are using FSDP.

def is_sharded_tensor(x: Any) -> bool
Expand source code
def is_sharded_tensor(x: tp.Any) -> bool:
    """Check whether ``x`` is a ``ShardedTensor``.

    Args:
        x: Any object. Values of other types, including ordinary tensors,
            yield ``False``.

    Returns:
        ``True`` if ``x`` is an instance of ``ShardedTensor``, else ``False``.
    """
    return isinstance(x, (ShardedTensor,))
def purge_fsdp(model: torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel)
Expand source code
def purge_fsdp(model: FSDP):
    """Purge the FSDP cached shard inside the model. This should
    allow setting the best state or switching to the EMA.

    Every FSDP unit of ``model`` whose padded unsharded flat parameter still
    has non-empty storage is resharded with FSDP's private ``_reshard``
    helper. Units without a handle, or whose storage is already empty, are
    skipped. A model that contains no FSDP unit is left untouched.

    Args:
        model: The FSDP-wrapped model to purge.
    """
    # Private FSDP API: the two branches below call it with different
    # arguments depending on the torch version.
    from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore

    def _holds_unsharded_storage(handle) -> bool:
        """Tell whether ``handle``'s padded unsharded flat param has storage."""
        flat_param = handle._get_padded_unsharded_flat_param()
        storage_size: int = flat_param._typed_storage()._size()  # type: ignore
        # Zero-size storage: there is no unsharded copy to drop.
        return storage_size != 0

    for module in FSDP.fsdp_modules(model):
        if hasattr(module, "_handles"):
            # support for FSDP with torch<2.1.0: a list of handles per unit,
            # and only the first handle's storage is inspected.
            handles = module._handles
            if not handles or not _holds_unsharded_storage(handles[0]):
                continue
            _reshard(module, handles, [True] * len(handles))
        else:
            # Newer torch: a single `_handle` attribute, which may be empty.
            handle = module._handle
            if not handle or not _holds_unsharded_storage(handle):
                continue
            _reshard(module, handle, True)

Purge the FSDP cached shard inside the model, i.e. reshard every FSDP unit that still holds unsharded parameter storage. This should allow setting the best state or switching to the EMA.

def switch_to_full_state_dict(models: List[torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel])
Expand source code
@contextmanager
def switch_to_full_state_dict(models: tp.List[FSDP]):
    """Temporarily put every model of ``models`` in FULL_STATE_DICT mode.

    While the context is active, each model uses
    ``FullStateDictConfig(offload_to_cpu=True, rank0_only=True)``. On exit,
    whether normal or through an exception, every model is set back to
    LOCAL_STATE_DICT, which is the type ``wrap_with_fsdp`` installs.

    Args:
        models: FSDP-wrapped models to switch.
    """
    # Another bug in FSDP prevents using its `state_dict_type` API, so the
    # switch is performed by hand.
    full_kind = StateDictType.FULL_STATE_DICT
    local_kind = StateDictType.LOCAL_STATE_DICT
    for fsdp_model in models:
        gather_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        FSDP.set_state_dict_type(fsdp_model, full_kind, gather_cfg)  # type: ignore
    try:
        yield
    finally:
        for fsdp_model in models:
            FSDP.set_state_dict_type(fsdp_model, local_kind)  # type: ignore
def wrap_with_fsdp(cfg,
                   model: torch.nn.modules.module.Module,
                   block_classes: Set[Type] | None = None) -> torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel
Expand source code
def wrap_with_fsdp(cfg, model: torch.nn.Module,
                   block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
    """Wraps a model with FSDP.

    Args:
        cfg: FSDP configuration. Fields read here:
            ``use``;
            ``param_dtype``, ``reduce_dtype`` and ``buffer_dtype``, each one
            of ``"float32"``, ``"float16"`` or ``"bfloat16"``;
            ``sharding_strategy``, one of ``"no_shard"``, ``"shard_grad_op"``
            or ``"full_shard"`` (the last one is rejected for now);
            ``per_block``.
        model: Module to wrap.
        block_classes: Module classes given to ``ModuleWrapPolicy`` when
            ``cfg.per_block`` is set. Defaults to
            ``{StreamingTransformerLayer, ConditioningProvider}``.

    Returns:
        The wrapped model (an ``_FSDPFixStateDict``), with its state dict type
        set to LOCAL_STATE_DICT.

    Raises:
        KeyError: if a dtype or sharding strategy name in ``cfg`` is unknown.
        AssertionError: if ``cfg.use`` is false, if ``"full_shard"`` is
            requested, or if the Dora local rank is not below
            ``torch.cuda.device_count()``.
    """
    # Some of the typing is disabled until this gets integrated
    # into the stable version of PyTorch.
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # type: ignore

    # we import this here to prevent circular import.
    from ..modules.transformer import StreamingTransformerLayer
    from ..modules.conditioners import ConditioningProvider

    _fix_post_backward_hook()

    assert cfg.use
    sharding_strategy_dict = {
        "no_shard": ShardingStrategy.NO_SHARD,
        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
        "full_shard": ShardingStrategy.FULL_SHARD,
    }

    dtype_dict = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    def _lookup(table: tp.Dict[str, tp.Any], name: tp.Any, option: str) -> tp.Any:
        """Return ``table[name]``, raising a KeyError that names ``option``."""
        try:
            return table[name]
        except KeyError:
            raise KeyError(
                f"Unsupported value {name!r} for FSDP option '{option}'; "
                f"expected one of {sorted(table)}.") from None

    mixed_precision_config = MixedPrecision(
        param_dtype=_lookup(dtype_dict, cfg.param_dtype, "param_dtype"),
        reduce_dtype=_lookup(dtype_dict, cfg.reduce_dtype, "reduce_dtype"),
        buffer_dtype=_lookup(dtype_dict, cfg.buffer_dtype, "buffer_dtype"),
    )

    sharding_strategy_config = _lookup(
        sharding_strategy_dict, cfg.sharding_strategy, "sharding_strategy")
    # The following is going to require being a bit smart
    # when doing LM, because this would flush the weights for every time step
    # during generation. One possibility is to use hybrid sharding:
    # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy
    assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \
        "Not supported at the moment, requires a bit more work."

    local_rank = dora.distrib.get_distrib_spec().local_rank
    assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"

    auto_wrap_policy = None
    if block_classes is None:
        block_classes = {StreamingTransformerLayer, ConditioningProvider}
    if cfg.per_block:
        # Each instance of a listed class becomes its own FSDP unit.
        auto_wrap_policy = ModuleWrapPolicy(block_classes)
    wrapped = _FSDPFixStateDict(
        model,
        sharding_strategy=sharding_strategy_config,
        mixed_precision=mixed_precision_config,
        device_id=local_rank,
        sync_module_states=True,
        use_orig_params=True,
        auto_wrap_policy=auto_wrap_policy,
    )  # type: ignore
    FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)  # type: ignore

    # Let the wrapped model know about the wrapping!
    # We use __dict__ to avoid it going into the state dict.
    # This is a bit dirty, but needed during generation, as otherwise
    # the wrapped model would call itself and bypass FSDP.
    for module in FSDP.fsdp_modules(wrapped):
        original = module._fsdp_wrapped_module
        original.__dict__['_fsdp'] = module
    return wrapped

Wraps a model with FSDP.