fairseq2.model_checkpoint¶
This module provides a memory-efficient model checkpoint loading API that lazily loads various checkpoint formats and supports distributed configurations through on-the-fly tensor resharding.
The loaders support:
Memory-efficient lazy loading that, where the underlying format allows it, avoids reading the entire checkpoint into memory at once. This is particularly relevant for large checkpoints that may not fit into memory.
On-the-fly tensor resharding across different distributed configurations.
Optional memory mapping for reduced memory footprint.
State dict conversion for format compatibility.
Automatic format detection.
from fairseq2.model_checkpoint import get_model_checkpoint_loader
from fairseq2.nn import get_shard_dims

model = ...  # PyTorch module
checkpoint_path = ...  # Checkpoint file

loader = get_model_checkpoint_loader()

# Get shard dimensions of each parameter of the model.
shard_dims = get_shard_dims(model)

# Load checkpoint.
for key, tensor in loader.lazy_load(checkpoint_path, shard_dims):
    # Process each tensor lazily without loading the entire checkpoint.
    ...
ABCs¶
- class fairseq2.model_checkpoint.ModelCheckpointLoader[source]¶
Bases: ABC
Represents the abstract base class for model checkpoint loaders.
This class defines the interface for checkpoint loaders that can efficiently load model state by yielding parameters lazily rather than loading everything into memory at once.
- abstract lazy_load(path: Path, shard_dims: Mapping[str, int], options: ModelCheckpointLoadOptions | None = None) → Iterator[tuple[str, Tensor]][source]¶
Lazily loads parameters from the specified checkpoint path.
Yields tensors one at a time to minimize memory usage if the underlying format allows it. Supports tensor resharding and optional state dictionary conversion.
If shard_dims is provided, it specifies the sharding dimension of each parameter as returned by get_sharding_dims(). Along with gangs, they enable on-the-fly parameter resharding during checkpoint loading.
Yields pairs of (parameter name, parameter) for each parameter in the checkpoint.
- Raises:
CorruptModelCheckpointError – Checkpoint is erroneous and cannot be loaded.
OSError – A system error occurred.
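For illustration, a minimal sketch of a custom loader built on a plain torch.save file. PickleCheckpointLoader is hypothetical, and it assumes that state_dict_converter is a callable mapping one state dict to another; a real implementation would also apply resharding via reshard_tensor() and options.gangs.

from collections.abc import Iterator, Mapping
from pathlib import Path

import torch
from torch import Tensor

from fairseq2.model_checkpoint import (
    ModelCheckpointLoader,
    ModelCheckpointLoadOptions,
)


class PickleCheckpointLoader(ModelCheckpointLoader):
    """Hypothetical loader for a single ``torch.save``-style checkpoint file."""

    def lazy_load(
        self,
        path: Path,
        shard_dims: Mapping[str, int],
        options: ModelCheckpointLoadOptions | None = None,
    ) -> Iterator[tuple[str, Tensor]]:
        if options is None:
            options = ModelCheckpointLoadOptions()

        # A plain pickle checkpoint cannot be read tensor-by-tensor, so the
        # whole state dict is loaded here; formats such as Safetensors would
        # allow true per-tensor laziness.
        state_dict = torch.load(path, map_location="cpu", mmap=options.mmap)

        # Assumption: `state_dict_converter` is a callable that returns a
        # converted state dict (e.g. with legacy keys renamed).
        if options.state_dict_converter is not None:
            state_dict = options.state_dict_converter(state_dict)

        # Resharding with `reshard_tensor()` is omitted for brevity; a real
        # loader would apply it per tensor before yielding.
        for key in list(state_dict.keys()):
            # Pop so each tensor can be freed once the caller is done with it.
            yield key, state_dict.pop(key)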
Classes¶
- class fairseq2.model_checkpoint.ModelCheckpointLoadOptions(*, gangs: 'Gangs | None' = None, mmap: 'bool' = False, restrict: 'bool' = True, state_dict_converter: 'StateDictConverter | None' = None)[source]¶
Bases: object
- gangs: Gangs | None = None¶
Used to determine the distributed target configuration and shard yielded parameters accordingly. If None, the gangs returned from get_current_gangs() will be used.
- mmap: bool = False¶
Indicates whether the checkpoint will be memory-mapped. This can reduce memory usage but may cause slower load times on some systems.
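A hedged usage sketch: memory-map the checkpoint and rename keys while loading. strip_module_prefix is a hypothetical converter, and checkpoint_path and shard_dims are assumed to be defined as in the module-level example above.

from fairseq2.model_checkpoint import (
    ModelCheckpointLoadOptions,
    get_model_checkpoint_loader,
)

# Hypothetical converter that strips a legacy "module." prefix from keys.
def strip_module_prefix(state_dict):
    return {key.removeprefix("module."): value for key, value in state_dict.items()}

options = ModelCheckpointLoadOptions(
    mmap=True,                                 # Memory-map the checkpoint file.
    state_dict_converter=strip_module_prefix,  # Convert keys on the fly.
)

loader = get_model_checkpoint_loader()

for key, tensor in loader.lazy_load(checkpoint_path, shard_dims, options):
    ...  # Process each tensor.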
Functions¶
- fairseq2.model_checkpoint.get_model_checkpoint_loader() → ModelCheckpointLoader[source]¶
- fairseq2.model_checkpoint.reshard_tensor(key: str, source_splits: list[list[Tensor]], source_shard_sizes: tuple[int, int], target_shard_sizes: tuple[int, int], target_shard_ranks: tuple[int, int], shard_dims: Mapping[str, int]) → Tensor[source]¶
Reshards a parameter tensor from a distributed source configuration to a target configuration.
This function is meant for authors of new ModelCheckpointLoader implementations. It handles the complex task of resharding tensors when loading checkpoints from one distributed configuration (e.g. 4-way tensor parallelism) to a different target configuration (e.g. 8-way tensor parallelism). It efficiently concatenates and slices tensors to produce the correct shards for the target rank.
The resharding process involves:
Determining if the tensor requires tensor parallelism based on specified shard dimensions.
For tensor parallel tensors, concatenating source shards and re-slicing for the target configuration in a memory-efficient way.
For replicated tensors, concatenating data parallel splits.
key specifies the name of the parameter whose sharding information is retrieved from shard_dims. See get_sharding_dims() for more information.
source_splits is a 2D list structure [tp_idx][dp_idx] containing the source tensor shards. The outer list specifies tensor parallel shards and the inner lists specify data parallel shards.
source_shard_sizes and target_shard_sizes specify the distributed source and target configurations respectively in the form of (tp_size, dp_size).
target_shard_ranks specifies the ranks of the current process in the target configuration in the form of (tp_rank, dp_rank).
shard_dims specifies the mapping from parameter names to the dimensions along which they should be sharded for tensor parallelism. Parameters omitted from the mapping are treated as replicated. See get_sharding_dims() for more information.
Returns the resharded tensor for the target rank and configuration.
Resharding from 2-way TP to 4-way TP¶

param_name = "model.weight"

# 2 TP shards with 1 DP shard each.
source_splits = [[tensor_tp0_dp0], [tensor_tp1_dp0]]

source_shard_sizes = (2, 1)  # 2-way TP, 1-way DP
target_shard_sizes = (4, 1)  # 4-way TP, 1-way DP
target_shard_ranks = (2, 0)  # Want shard for TP rank 2

# For a tensor with TP dim=0, this will concatenate the 2 source shards
# and slice out the portion corresponding to TP rank 2 in 4-way setup.
resharded = reshard_tensor(
    param_name,
    source_splits,
    source_shard_sizes,
    target_shard_sizes,
    target_shard_ranks,
    shard_dims={param_name: 0},
)
Note
This function deletes intermediate tensors during the resharding process to minimize peak memory usage.
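As a complement to the example above, here is a self-contained sketch with concrete tensor values. The values are hypothetical, and the final check assumes that shards are sliced evenly along the shard dimension.

import torch

from fairseq2.model_checkpoint import reshard_tensor

# Full (unsharded) weight of shape (8, 4), sharded 2-way along dim 0.
full_weight = torch.arange(32, dtype=torch.float32).reshape(8, 4)

tp_shards = full_weight.chunk(2, dim=0)  # Two (4, 4) source shards.

# One DP split per TP shard, i.e. source_splits[tp_idx][dp_idx].
source_splits = [[tp_shards[0]], [tp_shards[1]]]

resharded = reshard_tensor(
    "model.weight",
    source_splits,
    source_shard_sizes=(2, 1),  # Source: 2-way TP, 1-way DP.
    target_shard_sizes=(4, 1),  # Target: 4-way TP, 1-way DP.
    target_shard_ranks=(2, 0),  # Shard for TP rank 2, DP rank 0.
    shard_dims={"model.weight": 0},
)

# With even slicing, TP rank 2 of 4 owns rows 4:6 of the full weight.
assert torch.equal(resharded, full_weight[4:6])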