graph TD
A[Trainer] -->|uses| B[CheckpointManager]
B -->|saves| C[Model State]
B -->|saves| D[Optimizer State]
B -->|saves| E[Training Metadata]
B -->|manages| F[Checkpoint Files]
G[Model Loader] -->|loads| B
The fairseq2.checkpoint.manager.CheckpointManager provides a transactional API for saving checkpoints:
from pathlib import Path

from fairseq2.checkpoint import CheckpointNotFoundError, FileCheckpointManager

# Initialize checkpoint manager
ckpt_manager = FileCheckpointManager(
    checkpoint_dir=Path("checkpoints"),
    gang=root_gang,  # For distributed training coordination
)

# Begin checkpoint operation
ckpt_manager.begin_checkpoint(step_nr=1000)

# Save model and optimizer state
ckpt_manager.save_state(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step_nr": 1000,
        "epoch": 5,
    }
)

# Save validation score if needed
ckpt_manager.save_score(valid_score)

# Commit the checkpoint
ckpt_manager.commit_checkpoint()
To resume training, load the last committed checkpoint and restore the saved state:

try:
    # Load the last checkpoint
    step_nr, state = ckpt_manager.load_last_checkpoint()

    # Restore model and optimizer state
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])

    print(f"Restored checkpoint from step {step_nr}")
except CheckpointNotFoundError:
    print("No checkpoint found, starting fresh")
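Putting the two pieces together, a training loop can attempt to resume at startup and then checkpoint periodically. The sketch below is illustrative and reuses only the calls shown above; num_steps, save_every, train_step(), and compute_valid_score() are hypothetical stand-ins for your own training code.

def train(model, optimizer, ckpt_manager, num_steps, save_every=1000):
    # Try to resume from the last committed checkpoint.
    try:
        start_step, state = ckpt_manager.load_last_checkpoint()
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
    except CheckpointNotFoundError:
        start_step = 0

    for step_nr in range(start_step + 1, num_steps + 1):
        train_step(model, optimizer)  # hypothetical: runs one optimization step

        if step_nr % save_every == 0:
            ckpt_manager.begin_checkpoint(step_nr=step_nr)
            ckpt_manager.save_state(
                {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step_nr": step_nr,
                }
            )
            ckpt_manager.save_score(compute_valid_score(model))  # hypothetical
            ckpt_manager.commit_checkpoint()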
A non-sharded run produces the following checkpoint layout on disk:

checkpoint_dir/
├── model.yaml    # Model metadata
└── step_1000/    # Checkpoint at step 1000
    └── model.pt  # Model training state
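If the .pt files are ordinary torch.save artifacts (the extension suggests so, but the on-disk format is an implementation detail of the manager), a checkpoint can be inspected directly. The path below is hypothetical and follows the layout shown above.

from pathlib import Path

import torch

# Hypothetical path following the layout above.
ckpt_file = Path("checkpoints/step_1000/model.pt")

# Load on CPU and list the top-level keys (e.g. "model", "optimizer", "step_nr").
state = torch.load(ckpt_file, map_location="cpu")
print(sorted(state.keys()))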
For sharded checkpoints (FSDP), each rank has its own files:
checkpoint_dir/
├── model.yaml # Model metadata
└── step_1000/
    ├── model.pt   # Consolidated model
    ├── rank_0.pt  # Model rank 0 state
    └── rank_1.pt  # Model rank 1 state
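The consolidated model.pt contains the full (unsharded) model, while rank_0.pt and rank_1.pt hold the per-rank state. A minimal sketch for loading the consolidated file onto a single device follows; it assumes the file can be read with torch.load and stores the model weights either at the top level or under a "model" key, both of which are assumptions about the on-disk format.

import torch
from torch import nn

def load_consolidated(model: nn.Module, path: str = "checkpoints/step_1000/model.pt") -> nn.Module:
    """Load the consolidated checkpoint into an unwrapped (non-FSDP) module."""
    state = torch.load(path, map_location="cpu")

    # Assumption: the weights sit either at the top level or under a "model"
    # key; adjust to the actual on-disk format.
    model_state = state["model"] if isinstance(state, dict) and "model" in state else state

    model.load_state_dict(model_state)
    return model.eval()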