fairseq2.model_checkpoint

This module provides a memory-efficient model checkpoint loading API that lazily loads various checkpoint formats and supports distributed configurations through on-the-fly tensor resharding.

The loaders support:

  • Memory-efficient lazy loading that avoids reading an entire checkpoint into memory at once, provided the underlying format allows it. This is particularly relevant for large checkpoints that may not fit into memory.

  • On-the-fly tensor resharding across different distributed configurations.

  • Optional memory mapping for reduced memory footprint.

  • State dict conversion for format compatibility.

  • Automatic format detection.

Example Usage
from fairseq2.model_checkpoint import get_model_checkpoint_loader
from fairseq2.nn import get_shard_dims

model = ...  # PyTorch Module

checkpoint_path = ...  # Checkpoint file

# Get shard dimensions of each parameter of the model.
shard_dims = get_shard_dims(model)

# Get the model checkpoint loader.
loader = get_model_checkpoint_loader()

# Load checkpoint.
for key, tensor in loader.lazy_load(checkpoint_path, shard_dims):
    # Process each tensor lazily without loading the entire checkpoint.
    ...

ABCs

class fairseq2.model_checkpoint.ModelCheckpointLoader[source]

Bases: ABC

Represents the abstract base class for model checkpoint loaders.

This class defines the interface for checkpoint loaders that can efficiently load model state by yielding parameters lazily rather than loading everything into memory at once.

abstract lazy_load(path: Path, shard_dims: Mapping[str, int], options: ModelCheckpointLoadOptions | None = None) → Iterator[tuple[str, Tensor]][source]

Lazily loads parameters from the specified checkpoint path.

Yields tensors one at a time to minimize memory usage if the underlying format allows it. Supports tensor resharding and optional state dictionary conversion.

shard_dims specifies the sharding dimension of each parameter, as returned by get_shard_dims(). Together with the gangs option, it enables on-the-fly parameter resharding during checkpoint loading.

Yields pairs of (parameter name, parameter) for each parameter in the checkpoint.
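
As an illustration of the memory benefit, the following sketch streams parameters directly into an existing module instead of materializing a full state dict. It is an assumption-laden example: it presumes the checkpoint keys match the module's parameter names and uses placeholder values for the model and path.

import torch

from fairseq2.model_checkpoint import get_model_checkpoint_loader
from fairseq2.nn import get_shard_dims

model = ...  # PyTorch Module

checkpoint_path = ...  # Checkpoint file

loader = get_model_checkpoint_loader()

shard_dims = get_shard_dims(model)

# Hypothetical streaming load: copy each tensor into the matching parameter
# and release it immediately, so only one checkpoint tensor is alive at a time.
params = dict(model.named_parameters())

with torch.no_grad():
    for key, tensor in loader.lazy_load(checkpoint_path, shard_dims):
        params[key].copy_(tensor)
        del tensor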

Raises:
abstract supports_path(path: Path) → bool[source]

Checks if this loader can handle the specified checkpoint path.

Raises:

OSError – A system error occurred.
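
These two methods are the interface a new loader implements. Below is a minimal, hypothetical sketch for a single-file torch.save() format; it ignores resharding and handles only part of the options, so treat it as an illustration of the interface rather than a description of how the built-in loaders work. The converter call assumes a state-dict-to-state-dict callable.

from collections.abc import Iterator, Mapping
from pathlib import Path

import torch
from torch import Tensor

from fairseq2.model_checkpoint import (
    ModelCheckpointLoader,
    ModelCheckpointLoadOptions,
)


class PlainTorchCheckpointLoader(ModelCheckpointLoader):
    """Hypothetical loader for a single torch.save() file (no sharding)."""

    def supports_path(self, path: Path) -> bool:
        return path.suffix == ".pt"

    def lazy_load(
        self,
        path: Path,
        shard_dims: Mapping[str, int],
        options: ModelCheckpointLoadOptions | None = None,
    ) -> Iterator[tuple[str, Tensor]]:
        options = options or ModelCheckpointLoadOptions()

        # A plain torch.save() file cannot be read incrementally, so the whole
        # state dict is loaded here; real loaders would also honor `gangs` and
        # reshard with reshard_tensor() as needed.
        state_dict = torch.load(path, mmap=options.mmap, weights_only=options.restrict)

        # Assumed converter contract: takes a state dict, returns a state dict.
        if options.state_dict_converter is not None:
            state_dict = options.state_dict_converter(state_dict)

        for key, tensor in state_dict.items():
            yield key, tensor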

Classes

class fairseq2.model_checkpoint.ModelCheckpointLoadOptions(*, gangs: 'Gangs | None' = None, mmap: 'bool' = False, restrict: 'bool' = True, state_dict_converter: 'StateDictConverter | None' = None)[source]

Bases: object

gangs: Gangs | None = None

Used to determine the distributed target configuration and shard yielded parameters accordingly. If None, the gangs returned from get_current_gangs() will be used.

mmap: bool = False

Indicates whether the checkpoint will be memory-mapped. This can reduce memory usage but may cause slower load times on some systems.

restrict: bool = True

Indicates whether the unpickler (if used) will be restricted to loading only tensors and types that can be safely serialized and deserialized.

state_dict_converter: StateDictConverter | None = None

If provided, used to transform the (sharded) state dictionaries in the checkpoint from one format, such as Hugging Face Transformers, to fairseq2.
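
A sketch of combining these options; the checkpoint path and the key-renaming converter are made up for illustration, and the StateDictConverter is assumed to be a callable that maps a state dict to a state dict.

from pathlib import Path

from fairseq2.model_checkpoint import (
    ModelCheckpointLoadOptions,
    get_model_checkpoint_loader,
)


def convert_state_dict(state_dict):
    # Hypothetical converter: strip a "model." prefix added by another framework.
    return {key.removeprefix("model."): value for key, value in state_dict.items()}


options = ModelCheckpointLoadOptions(
    mmap=True,                                # memory-map the checkpoint if supported
    restrict=True,                            # keep the restricted unpickler enabled
    state_dict_converter=convert_state_dict,
)

loader = get_model_checkpoint_loader()

checkpoint_path = Path("model.safetensors")  # hypothetical checkpoint file

# Empty shard_dims: treat every parameter as replicated (no TP resharding).
for key, tensor in loader.lazy_load(checkpoint_path, shard_dims={}, options=options):
    print(key, tuple(tensor.shape))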

Functions

fairseq2.model_checkpoint.get_model_checkpoint_loader() → ModelCheckpointLoader[source]

fairseq2.model_checkpoint.reshard_tensor(key: str, source_splits: list[list[Tensor]], source_shard_sizes: tuple[int, int], target_shard_sizes: tuple[int, int], target_shard_ranks: tuple[int, int], shard_dims: Mapping[str, int]) → Tensor[source]

Reshards a parameter tensor from a distributed source configuration to a target configuration.

This function is meant for authors of new ModelCheckpointLoader implementations. It handles the complex task of resharding tensors when a checkpoint saved under one distributed configuration (e.g. 4-way tensor parallelism) is loaded under a different target configuration (e.g. 8-way tensor parallelism). It efficiently concatenates and slices tensors to produce the correct shards for the target rank.

The resharding process involves:

  1. Determining if the tensor requires tensor parallelism based on specified shard dimensions.

  2. For tensor parallel tensors, concatenating source shards and re-slicing for the target configuration in a memory-efficient way.

  3. For replicated tensors, concatenating data parallel splits.

key specifies the name of the parameter whose sharding information is looked up in shard_dims. See get_shard_dims() for more information.

source_splits is a 2D list structure [tp_idx][dp_idx] containing the source tensor shards. The outer list specifies tensor parallel shards and inner lists specify data parallel shards.

source_shard_sizes and target_shard_sizes specify the distributed source and target configurations respectively in the form of (tp_size, dp_size).

target_shard_ranks specifies the ranks of the current process in the target configuration in the form of (tp_rank, dp_rank).

shard_dims maps parameter names to the dimensions along which they are sharded for tensor parallelism. Parameters omitted from the mapping are treated as replicated. See get_shard_dims() for more information.

Returns the resharded tensor for the target rank and configuration.

Resharding from 2-way TP to 4-way TP
import torch

from fairseq2.model_checkpoint import reshard_tensor

param_name = "model.weight"

# 2 TP shards with 1 DP shard each (illustrative (4, 4) halves of an (8, 4) weight).
tensor_tp0_dp0 = torch.randn(4, 4)
tensor_tp1_dp0 = torch.randn(4, 4)

source_splits = [[tensor_tp0_dp0], [tensor_tp1_dp0]]

source_shard_sizes = (2, 1)  # 2-way TP, 1-way DP
target_shard_sizes = (4, 1)  # 4-way TP, 1-way DP

target_shard_ranks = (2, 0)  # Want shard for TP rank 2

# For a tensor with TP dim=0, this will concatenate the 2 source shards
# and slice out the portion corresponding to TP rank 2 in the 4-way setup.
resharded = reshard_tensor(
    param_name,
    source_splits,
    source_shard_sizes,
    target_shard_sizes,
    target_shard_ranks,
    shard_dims={param_name: 0},
)
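
The same call can consolidate a sharded checkpoint into a single unsharded tensor by targeting a 1-way configuration. A minimal sketch under that assumption, with illustrative shard shapes:

import torch

from fairseq2.model_checkpoint import reshard_tensor

param_name = "model.weight"

# Two (4, 4) TP shards of an (8, 4) weight, one DP split each.
source_splits = [[torch.randn(4, 4)], [torch.randn(4, 4)]]

# Consolidate to a single, unsharded tensor: 1-way TP, 1-way DP, rank (0, 0).
full_weight = reshard_tensor(
    param_name,
    source_splits,
    (2, 1),  # source: 2-way TP, 1-way DP
    (1, 1),  # target: 1-way TP, 1-way DP
    (0, 0),  # target TP rank 0, DP rank 0
    shard_dims={param_name: 0},
)

print(full_weight.shape)  # Expected: torch.Size([8, 4]), the shards concatenated along dim 0.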

Note

This function deletes intermediate tensors during the resharding process to minimize peak memory usage.

Exceptions

class fairseq2.model_checkpoint.CorruptModelCheckpointError(path: Path, message: str)[source]

Bases: Exception
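
Given the constructor signature above, a load can be guarded against corrupt checkpoints along these lines; the checkpoint path and the empty shard_dims below are placeholders.

from pathlib import Path

from fairseq2.model_checkpoint import (
    CorruptModelCheckpointError,
    get_model_checkpoint_loader,
)

loader = get_model_checkpoint_loader()

checkpoint_path = Path("checkpoint.pt")  # hypothetical checkpoint file

try:
    # Empty shard_dims: treat every parameter as replicated (no TP resharding).
    state_dict = dict(loader.lazy_load(checkpoint_path, shard_dims={}))
except CorruptModelCheckpointError as ex:
    print(f"Checkpoint at {checkpoint_path} could not be loaded: {ex}")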