graph TD
A[Trainer] -->|uses| B[CheckpointManager]
B -->|saves| C[Model State]
B -->|saves| D[Optimizer State]
B -->|saves| E[Training Metadata]
B -->|manages| F[Checkpoint Files]
G[Model Loader] -->|loads| B
The fairseq2.checkpoint.manager.CheckpointManager interface provides a transactional API for saving checkpoints. The example below uses FileCheckpointManager, its file-based implementation:
# Initialize the checkpoint manager
ckpt_manager = FileCheckpointManager(
    checkpoint_dir=Path("checkpoints"),
    gangs=root_gang,              # For distributed training coordination
    file_system=file_system,      # File system abstraction
    tensor_loader=tensor_loader,  # For loading tensors
    tensor_dumper=tensor_dumper,  # For saving tensors
)

# Begin the checkpoint operation
ckpt_manager.begin_checkpoint(step_nr=1000)

# Save model and optimizer state
ckpt_manager.save_state(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step_nr": 1000,
        "epoch": 5,
    },
    replicated_keys={"epoch"},  # Keys that are the same across all processes
)

# Save the validation score if needed
ckpt_manager.save_score(valid_score, lower_better=True)  # Optional; lower is better

# Commit the checkpoint
ckpt_manager.commit_checkpoint()
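The begin/save/commit split is what makes the operation transactional: a checkpoint only counts as complete once commit_checkpoint() returns, so an interrupted save should not surface a partially written checkpoint to load_last_checkpoint().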
To resume training, load the most recent committed checkpoint:

try:
    # Load the last checkpoint
    step_nr, state = ckpt_manager.load_last_checkpoint()

    # Restore model and optimizer state
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])

    print(f"Restored checkpoint from step {step_nr}")
except CheckpointNotFoundError:
    print("No checkpoint found, starting fresh")
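As a minimal sketch of how these two pieces fit together in practice, the loop below resumes from the last checkpoint (if any) and saves a new one every save_every steps. It reuses only the calls shown above; train_step, num_steps, and save_every are hypothetical stand-ins for your own training code.

def train(ckpt_manager, model, optimizer, train_step, num_steps, save_every=1000):
    # Resume from the last committed checkpoint, if one exists.
    try:
        start_step, state = ckpt_manager.load_last_checkpoint()
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
    except CheckpointNotFoundError:
        start_step = 0

    for step_nr in range(start_step + 1, num_steps + 1):
        train_step(model, optimizer)  # Hypothetical: one optimization step

        if step_nr % save_every == 0:
            # Transactional save: begin, write state, then commit.
            ckpt_manager.begin_checkpoint(step_nr=step_nr)
            ckpt_manager.save_state(
                {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step_nr": step_nr,
                }
            )
            ckpt_manager.commit_checkpoint()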
On disk, checkpoints are organized as follows:

checkpoint_dir/
├── model.yaml # Model metadata
├── cc/ # Carbon copy directory; its contents are copied into each checkpoint
└── step_1000/ # Checkpoint at step 1000
├── model.pt # Model training state
├── rank_0.pt # Process-specific state for rank 0
├── rank_1.pt # Process-specific state for rank 1
└── score.txt # Optional validation score
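This layout makes it easy to locate the most recent checkpoint by hand, e.g. when inspecting a run outside the trainer. The helper below illustrates the directory convention only, not fairseq2's own discovery logic; find_latest_step_dir is a hypothetical name.

from pathlib import Path

def find_latest_step_dir(checkpoint_dir: Path) -> Path | None:
    # Pick the step_<N> directory with the highest step number, if any.
    step_dirs = [
        d
        for d in checkpoint_dir.glob("step_*")
        if d.is_dir() and d.name.split("_", 1)[1].isdigit()
    ]
    if not step_dirs:
        return None
    return max(step_dirs, key=lambda d: int(d.name.split("_", 1)[1]))

latest = find_latest_step_dir(Path("checkpoints"))  # e.g. checkpoints/step_1000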
For tensor parallel training, model files are suffixed with the TP rank:
checkpoint_dir/
├── model.yaml
└── step_1000/
├── model.0.pt # Model shard for TP rank 0
├── model.1.pt # Model shard for TP rank 1
├── replicated.0.pt # Replicated state for TP rank 0
├── replicated.1.pt # Replicated state for TP rank 1
└── score.txt
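As a rough sketch of what this sharded layout implies for loading, each tensor parallel rank needs only its own model.<tp_rank>.pt shard plus the matching replicated.<tp_rank>.pt file. The snippet below is illustrative only; in practice the CheckpointManager handles shard selection for you, and load_tp_shard is a hypothetical helper.

from pathlib import Path

import torch

def load_tp_shard(step_dir: Path, tp_rank: int, device: str = "cpu"):
    # Each tensor parallel rank reads its own shard of the model state
    # along with the state that is replicated across TP ranks.
    shard_state = torch.load(step_dir / f"model.{tp_rank}.pt", map_location=device)
    replicated_state = torch.load(step_dir / f"replicated.{tp_rank}.pt", map_location=device)
    return shard_state, replicated_state

shard, replicated = load_tp_shard(Path("checkpoint_dir/step_1000"), tp_rank=0)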