The trainer in fairseq2 is designed to be flexible and model-agnostic, handling training scenarios that range from simple single-model setups to complex distributed training. It is probably the most complex system in fairseq2, but also the most powerful. The diagram below shows the main components the trainer orchestrates:
```mermaid
flowchart LR
    %% Main Trainer Class
    A[Trainer] --> B[TrainUnit]
    A --> C[DataReader]
    A --> D[Optimizer]
    A --> E[CheckpointManager]
    A --> H[LRScheduler]
    A --> I[Gang System]
    A --> P[Metrics Logging]
    A --> V[Validation]

    %% TrainUnit Components
    B --> F[Model]
    B --> G[MetricBag]

    %% Gang System
    I --> J[Root Gang]
    I --> K[DP Gang]
    I --> L[TP Gang]

    %% Metrics Logging
    P --> P1[TensorBoard]
    P --> P2[WandB]
    P --> P3[JSON Logger]

    %% Validation
    V --> Q[EvalUnit]
    V --> R[Validation DataReader]

    %% CheckpointManager Details
    E --> E1[Save State]
    E --> E2[Load State]
    E --> E3[Keep Best Checkpoints]
    E --> E4[Save FSDP Model]
```
The TrainUnit is an abstract class that encapsulates model-specific training logic:
```python
class TrainUnit(ABC, Generic[BatchT_contra]):
    """Represents a unit to be used with Trainer."""

    @abstractmethod
    def __call__(self, batch: BatchT_contra) -> tuple[Tensor, int | None]:
        """Process batch and return loss and number of targets."""

    @abstractmethod
    def set_step_nr(self, step_nr: int) -> None:
        """Set current training step number."""

    @property
    @abstractmethod
    def model(self) -> Module:
        """The underlying model."""

    @property
    @abstractmethod
    def metric_bag(self) -> MetricBag:
        """Training-related metrics."""
```
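To make the contract concrete, here is a minimal sketch of what a unit for a simple classifier could look like. Everything in it (the ClassificationBatch type, the duck-typed ClassificationTrainUnit, the cross-entropy loss) is a hypothetical illustration rather than fairseq2 code; it mirrors the interface above but omits the metric_bag property so the example stays self-contained.

```python
from dataclasses import dataclass

import torch
from torch import Tensor, nn
from torch.nn import Module
from torch.nn.functional import cross_entropy


@dataclass
class ClassificationBatch:
    """Hypothetical batch type: a feature matrix and integer class labels."""

    features: Tensor  # shape (N, D)
    labels: Tensor    # shape (N,)


class ClassificationTrainUnit:
    """Hypothetical train unit that mirrors the TrainUnit interface above."""

    def __init__(self, model: Module) -> None:
        self._model = model
        self._step_nr = 0

    def __call__(self, batch: ClassificationBatch) -> tuple[Tensor, int]:
        logits = self._model(batch.features)                          # forward pass
        loss = cross_entropy(logits, batch.labels, reduction="sum")   # summed loss
        return loss, batch.labels.numel()                             # loss + target count

    def set_step_nr(self, step_nr: int) -> None:
        self._step_nr = step_nr

    @property
    def model(self) -> Module:
        return self._model


# Usage: the trainer calls the unit once per batch and drives the backward pass itself.
unit = ClassificationTrainUnit(nn.Linear(16, 4))
batch = ClassificationBatch(torch.randn(8, 16), torch.randint(0, 4, (8,)))
loss, num_targets = unit(batch)
```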
An example Trainer configuration:

```python
trainer = Trainer(
    # Basic parameters
    unit=train_unit,                     # Training unit to compute the loss
    data_reader=data_reader,             # Data reader for training batches
    optimizer=optimizer,                 # Optimizer
    checkpoint_manager=checkpoint_mgr,   # Checkpoint manager
    root_gang=root_gang,                 # Root gang for distributed training
    # Optional parameters
    dp_gang=dp_gang,                     # Data parallel gang
    tp_gang=tp_gang,                     # Tensor parallel gang
    dtype=torch.float32,                 # Model data type
    lr_scheduler=lr_scheduler,           # Learning rate scheduler
    max_num_steps=100_000,               # Maximum training steps
    max_num_data_epochs=10,              # Maximum training epochs
    # Validation parameters
    valid_units=[valid_unit],            # Validation units
    valid_data_readers=[valid_reader],   # Validation data readers
    validate_every_n_steps=1_000,        # Validation frequency
    # Checkpoint parameters
    checkpoint_every_n_steps=5_000,      # Checkpoint frequency
    keep_last_n_checkpoints=5,           # Number of checkpoints to keep
    # Metric parameters
    publish_metrics_every_n_steps=100,   # Metric publishing frequency
    tb_dir=Path("runs"),                 # TensorBoard directory
    metrics_dir=Path("metrics"),         # Metrics directory
)
```
The training process follows this simplified sequence:
```mermaid
sequenceDiagram
    participant T as Trainer
    participant U as TrainUnit
    participant D as DataReader
    participant M as Model
    participant O as Optimizer

    T->>D: Request batch
    D-->>T: Return batch
    T->>U: Process batch
    U->>M: Forward pass
    M-->>U: Return loss
    U-->>T: Return loss, num_targets
    T->>M: Backward pass
    T->>O: Update parameters
    T->>T: Update metrics
```
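For orientation, one iteration of this sequence corresponds roughly to the plain PyTorch loop below. The names (model, optimizer, batches) are placeholders for this sketch; the real Trainer additionally handles loss scaling, gradient accumulation, distributed synchronization, and metric publication.

```python
import torch
from torch import nn

# Placeholder components standing in for the model, optimizer, and data reader.
model = nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
batches = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(3)]

for features, labels in batches:                                  # DataReader: request/return batch
    loss = nn.functional.cross_entropy(model(features), labels)   # TrainUnit: forward pass, return loss
    optimizer.zero_grad()
    loss.backward()                                               # Trainer: backward pass
    optimizer.step()                                              # Optimizer: update parameters
    # ...the Trainer would now update its metrics for this step
```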
Step-by-step breakdown
The following simplified code snippets walk through the trainer's flow step by step.
Initialization: The trainer is initialized with the necessary components and configurations.
```python
def __init__(
    self,
    unit: TrainUnit[BatchT],
    data_reader: DataReader[BatchT],
    ...
):
    self._model = unit.model
    self._unit = unit
    self._data_reader = data_reader
    # ... initialize other components
```
Training Loop: The training loop is implemented in the _do_run method:
```python
def _do_run(self) -> None:
    while self._should_run_step():
        self._step_nr += 1

        # Run training step
        self._run_step()

        # Maybe validate
        if self._should_validate():
            self._validate()

        # Maybe checkpoint
        if self._should_checkpoint():
            self._checkpoint()
```
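The loop condition itself is not shown above; conceptually it enforces the max_num_steps and max_num_data_epochs limits passed to the Trainer. A free-standing sketch of that check (a hypothetical illustration, not the actual _should_run_step implementation) might look like this:

```python
def should_run_step(
    step_nr: int,
    data_epoch_nr: int,
    max_num_steps: int | None,
    max_num_data_epochs: int | None,
) -> bool:
    """Hypothetical sketch of the condition that keeps the training loop running."""
    if max_num_steps is not None and step_nr >= max_num_steps:
        return False

    if max_num_data_epochs is not None and data_epoch_nr >= max_num_data_epochs:
        return False

    return True


# The loop continues early in training and stops once a limit is reached.
assert should_run_step(10, 0, max_num_steps=100_000, max_num_data_epochs=10)
assert not should_run_step(100_000, 3, max_num_steps=100_000, max_num_data_epochs=10)
```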
Step Execution: The _run_step method executes a single training step:
```python
def _run_step(self) -> None:
    # Collect batches
    batches = self._next_batches()

    # Process each batch
    for batch in batches:
        # Forward pass
        loss, num_targets = self._unit(batch)

        # Backward pass
        self._loss_scaler.backward(loss)

    # Update parameters
    self._loss_scaler.run_optimizer_step(self._step_nr)
```
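Note that the loss returned by the unit is typically a sum, and num_targets tells the trainer how many target elements it covers. The sketch below illustrates, with placeholder names rather than the actual fairseq2 code, how gradients accumulated over the batches of one step can be normalized by the total target count before the optimizer update:

```python
import torch
from torch import nn

model = nn.Linear(16, 4)                      # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Two gradient-accumulation batches belonging to the same training step.
batches = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(2)]

total_num_targets = 0

optimizer.zero_grad()

for features, labels in batches:
    # The "unit": a summed loss plus the number of targets it covers.
    loss = nn.functional.cross_entropy(model(features), labels, reduction="sum")
    num_targets = labels.numel()

    loss.backward()                           # gradients of the summed loss accumulate
    total_num_targets += num_targets

# Normalize once so the effective loss is a mean over all targets in the step.
for param in model.parameters():
    if param.grad is not None:
        param.grad /= total_num_targets

optimizer.step()
```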
Validation: The validation loop is implemented in the _validate method:
```python
def _validate(self) -> None:
    log.info("Starting validation after step {}.", self._step_nr)

    self._model.eval()

    with summon_fsdp_for_validation(self._model):
        unit_scores = []

        for unit, data_reader in zip(self._valid_units, self._valid_data_readers):
            unit_score = self._validate_unit(unit, data_reader)
            if unit_score is not None:
                unit_scores.append(unit_score)

        self._valid_score = self._compute_valid_score(unit_scores)

    self._model.train()

    log.info("Validation complete.")
```
- Validation occurs at specified intervals (steps or epochs), e.g. validate_every_n_steps or validate_every_n_data_epochs.
- It computes a score (such as accuracy) using fairseq2.recipes.evaluator.EvalUnit objects and logs the metrics.
- The validation score is compared to previous scores (see the sketch after this list) in order to:
  - save the best checkpoints;
  - stop early if performance stagnates.
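As a rough illustration of that bookkeeping, the snippet below averages per-unit scores into a single validation score and tracks the best one seen so far. The function names and the lower-is-better convention are assumptions for the example, not the Trainer's actual internals.

```python
def compute_valid_score(unit_scores: list[float]) -> float | None:
    """Average the per-unit scores into one validation score (hypothetical)."""
    if not unit_scores:
        return None
    return sum(unit_scores) / len(unit_scores)


def is_better(score: float, best_score: float | None, lower_better: bool) -> bool:
    """Decide whether a new score beats the best one recorded so far."""
    if best_score is None:
        return True
    return score < best_score if lower_better else score > best_score


best = None
for score in (0.91, 0.88, 0.93):       # e.g. loss values from successive validations
    if is_better(score, best, lower_better=True):
        best = score                    # a best checkpoint would be kept here
print(best)                             # 0.88
```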
Checkpoint: The checkpointing logic is implemented in the _checkpoint method:
```python
def _checkpoint(self) -> None:
    # Save checkpoint
    step_nr = self._step_nr

    self._checkpoint_manager.begin_checkpoint(step_nr)

    self._checkpoint_manager.save_state(self.state_dict(), model_key="_model")
```
- The trainer saves checkpoints periodically at specified intervals (steps or epochs), e.g. checkpoint_every_n_steps or checkpoint_every_n_data_epochs.
- Both the model weights and the optimizer state are saved.
- The best-performing models are kept based on the validation score, e.g. keep_best_n_checkpoints=3 (see the retention sketch after this list).
- The checkpoint manager handles saving and loading checkpoints, which ensures that:
  - training can be resumed after interruptions;
  - the best models are preserved for deployment.
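The retention sketch below shows the idea behind keep_last_n_checkpoints and keep_best_n_checkpoints: keep the newest checkpoints plus the best-scoring ones and let the rest be deleted. It is a hypothetical illustration, not the CheckpointManager's actual pruning code.

```python
def select_checkpoints_to_keep(
    steps: list[int],
    scores: dict[int, float],
    keep_last_n: int,
    keep_best_n: int,
    lower_better: bool = True,
) -> set[int]:
    """Hypothetical sketch: keep the newest N checkpoints plus the best-scoring N."""
    last = set(sorted(steps)[-keep_last_n:])
    ranked = sorted(scores, key=scores.get, reverse=not lower_better)
    best = set(ranked[:keep_best_n])
    return last | best


steps = [1000, 2000, 3000, 4000, 5000]
scores = {1000: 2.1, 2000: 1.7, 3000: 1.9, 4000: 1.6, 5000: 1.8}

# Keeps steps 4000 and 5000 (newest) plus 2000 and 4000 (lowest scores).
print(select_checkpoints_to_keep(steps, scores, keep_last_n=2, keep_best_n=2))
```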
Metrics Logging: Training metrics are collected in the unit's MetricBag and published by the _publish_metrics method at the configured interval (e.g. publish_metrics_every_n_steps) to the enabled backends, such as TensorBoard, WandB, and JSON log files.
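As a rough picture of what such publishing involves, the sketch below writes a metric dictionary to TensorBoard and to a JSON-lines file at a given step. It assumes tensorboard is installed and uses placeholder paths; the real trainer first gathers and syncs these values from the unit's MetricBag across the gang.

```python
import json
from pathlib import Path

from torch.utils.tensorboard import SummaryWriter


def publish_metrics(
    values: dict[str, float],
    step_nr: int,
    tb_writer: SummaryWriter,
    metrics_file: Path,
) -> None:
    """Hypothetical sketch of publishing one set of training metrics."""
    # Scalar metrics go to TensorBoard under a common prefix.
    for name, value in values.items():
        tb_writer.add_scalar(f"train/{name}", value, step_nr)

    # The same values are appended to a JSON-lines file for offline analysis.
    metrics_file.parent.mkdir(parents=True, exist_ok=True)
    with metrics_file.open("a") as fp:
        fp.write(json.dumps({"step": step_nr, **values}) + "\n")


writer = SummaryWriter(log_dir="runs")
publish_metrics({"loss": 2.31, "lr": 1e-4}, 100, writer, Path("metrics/train.jsonl"))
```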
Beyond the core loop, the trainer exposes a few customization hooks.

Early Stopping: A custom early_stopper callable can end training when the validation score stops improving:

```python
def early_stopper(step_nr: int, score: float) -> bool:
    # Custom early stopping logic
    return score < threshold


trainer = Trainer(
    early_stopper=early_stopper,
    score_metric_name="validation_loss",
    lower_better=True,
)
```
Custom Learning Rate Scheduling:
```python
class CustomLRScheduler(LRScheduler):
    def get_lr(self) -> float:
        # Custom LR calculation
        return self.base_lr * decay_factor(self.step_nr)


trainer = Trainer(
    lr_scheduler=CustomLRScheduler(optimizer),
)
```
Profiling:
```python
trainer = Trainer(
    profile=(100, 10),        # Skip 100 steps, profile 10 steps
    tb_dir=Path("logs/tb"),   # Save profiles to TensorBoard
)
```