Module audiocraft.optim.fsdp
Wrapper around FSDP for more convenient use in the training loops.
Functions
def is_fsdp_used() ‑> bool
def is_fsdp_used() -> bool:
    """Return whether we are using FSDP."""
    # A bit of a hack but should work from anywhere.
    if dora.is_xp():
        cfg = dora.get_xp().cfg
        if hasattr(cfg, 'fsdp'):
            return cfg.fsdp.use
    return False
Return whether we are using FSDP.
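A minimal usage sketch (the helper below is hypothetical): because the check reads the active Dora XP config, it only returns True when called from inside a Dora run whose config defines an `fsdp.use` flag.

from audiocraft.optim import fsdp

def describe_checkpoint_layout():
    # Hypothetical helper: pick a message based on whether FSDP is active.
    # `is_fsdp_used()` inspects the current Dora XP config (cfg.fsdp.use).
    if fsdp.is_fsdp_used():
        return "FSDP enabled: checkpoints contain per-rank local shards."
    return "FSDP disabled: checkpoints contain the full state dict."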
def is_sharded_tensor(x: Any) ‑> bool
def is_sharded_tensor(x: tp.Any) -> bool:
    return isinstance(x, ShardedTensor)
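For illustration, a checkpointing routine might use this predicate to separate entries of a local state dict that are still ShardedTensor instances from plain tensors; the helper below is a hypothetical sketch, not part of the module.

from audiocraft.optim.fsdp import is_sharded_tensor

def split_sharded_entries(state_dict):
    # Hypothetical helper: route plain tensors and sharded tensors to
    # different save/gather code paths.
    plain, sharded = {}, {}
    for name, value in state_dict.items():
        (sharded if is_sharded_tensor(value) else plain)[name] = value
    return plain, sharded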
def purge_fsdp(model: torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel)
def purge_fsdp(model: FSDP):
    """Purge the FSDP cached shard inside the model.

    This should allow setting the best state or switching to the EMA.
    """
    from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore
    for module in FSDP.fsdp_modules(model):
        if hasattr(module, "_handles"):
            # support for FSDP with torch<2.1.0
            handles = module._handles
            if not handles:
                continue
            handle = handles[0]
            unsharded_flat_param = handle._get_padded_unsharded_flat_param()
            storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
            if storage_size == 0:
                continue
            true_list = [True for h in handles]
            _reshard(module, handles, true_list)
        else:
            handle = module._handle
            if not handle:
                continue
            unsharded_flat_param = handle._get_padded_unsharded_flat_param()
            storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
            if storage_size == 0:
                continue
            _reshard(module, handle, True)
Purge the FSDP cached shard inside the model. This should allow setting the best state or switching to the EMA.
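A sketch of the intended call pattern, assuming `model` is already FSDP-wrapped and `ema_state` holds a compatible (local) state dict; the surrounding helper is illustrative only.

import torch
from audiocraft.optim.fsdp import purge_fsdp

def swap_in_ema(model, ema_state):
    # Hypothetical helper: drop cached unsharded flat parameters before
    # loading new weights, so stale full copies are not kept around.
    purge_fsdp(model)
    with torch.no_grad():
        model.load_state_dict(ema_state)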
def switch_to_full_state_dict(models: List[torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel])
@contextmanager
def switch_to_full_state_dict(models: tp.List[FSDP]):
    # Another FSDP bug prevents us from using the `state_dict_type` API,
    # so we set the state dict type manually.
    for model in models:
        FSDP.set_state_dict_type(  # type: ignore
            model, StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True))
    try:
        yield
    finally:
        for model in models:
            FSDP.set_state_dict_type(model, StateDictType.LOCAL_STATE_DICT)  # type: ignore
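A minimal sketch of exporting a consolidated checkpoint with this context manager; the output path and the rank check via torch.distributed are assumptions, not part of this module. Inside the context, state_dict() returns the full (unsharded) weights, offloaded to CPU and materialized only on rank 0, so only rank 0 should write the file.

import torch
import torch.distributed as dist
from audiocraft.optim.fsdp import switch_to_full_state_dict

def save_consolidated(models, path="checkpoint_full.th"):
    # Hypothetical helper: gather full state dicts, then save from rank 0 only.
    with switch_to_full_state_dict(models):
        states = [m.state_dict() for m in models]
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(states, path)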
def wrap_with_fsdp(cfg,
model: torch.nn.modules.module.Module,
block_classes: Set[Type] | None = None) ‑> torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel
def wrap_with_fsdp(cfg, model: torch.nn.Module,
                   block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
    """Wraps a model with FSDP."""
    # Some of the typing is disabled until this gets integrated
    # into the stable version of PyTorch.
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # type: ignore

    # We import these here to prevent a circular import.
    from ..modules.transformer import StreamingTransformerLayer
    from ..modules.conditioners import ConditioningProvider

    _fix_post_backward_hook()

    assert cfg.use
    sharding_strategy_dict = {
        "no_shard": ShardingStrategy.NO_SHARD,
        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
        "full_shard": ShardingStrategy.FULL_SHARD,
    }

    dtype_dict = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    mixed_precision_config = MixedPrecision(
        param_dtype=dtype_dict[cfg.param_dtype],
        reduce_dtype=dtype_dict[cfg.reduce_dtype],
        buffer_dtype=dtype_dict[cfg.buffer_dtype],
    )

    sharding_strategy_config = sharding_strategy_dict[cfg.sharding_strategy]
    # The following is going to require being a bit smart
    # when doing LM, because this would flush the weights for every time step
    # during generation. One possibility is to use hybrid sharding:
    # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy
    assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \
        "Not supported at the moment, requires a bit more work."

    local_rank = dora.distrib.get_distrib_spec().local_rank
    assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"

    auto_wrap_policy = None
    if block_classes is None:
        block_classes = {StreamingTransformerLayer, ConditioningProvider}
    if cfg.per_block:
        auto_wrap_policy = ModuleWrapPolicy(block_classes)
    wrapped = _FSDPFixStateDict(
        model,
        sharding_strategy=sharding_strategy_config,
        mixed_precision=mixed_precision_config,
        device_id=local_rank,
        sync_module_states=True,
        use_orig_params=True,
        auto_wrap_policy=auto_wrap_policy,
    )  # type: ignore
    FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)  # type: ignore

    # Let the wrapped model know about the wrapping!
    # We use __dict__ to avoid it going into the state dict.
    # This is a bit dirty, but needed during generation, as otherwise
    # the wrapped model would call itself and bypass FSDP.
    for module in FSDP.fsdp_modules(wrapped):
        original = module._fsdp_wrapped_module
        original.__dict__['_fsdp'] = module
    return wrapped
Wraps a model with FSDP.
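A sketch of the expected call, assuming a config object shaped like the fields the wrapper reads (values below are illustrative) and an already-initialized Dora/distributed environment with one CUDA device per rank; build_model() is a hypothetical stand-in for any torch.nn.Module containing the default block classes.

from omegaconf import OmegaConf
from audiocraft.optim.fsdp import wrap_with_fsdp

# Illustrative config covering the fields the wrapper reads:
# use, param_dtype, reduce_dtype, buffer_dtype, sharding_strategy, per_block.
fsdp_cfg = OmegaConf.create({
    'use': True,
    'param_dtype': 'float16',
    'reduce_dtype': 'float32',
    'buffer_dtype': 'float32',
    'sharding_strategy': 'shard_grad_op',  # 'full_shard' is rejected by an assert above
    'per_block': True,
})

model = build_model()  # hypothetical: returns a torch.nn.Module with transformer blocks
wrapped = wrap_with_fsdp(fsdp_cfg, model)  # FSDP-wrapped, LOCAL_STATE_DICT by default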