The trainer in fairseq2 is designed to be flexible and model-agnostic, handling various training scenarios from simple models to complex distributed training setups.
It is probably the most complex system in fairseq2, but also the most powerful.
flowchart LR
%% Main Trainer Class
A[Trainer] --> B[TrainUnit]
A --> C[DataReader]
A --> D[Optimizer]
A --> E[CheckpointManager]
A --> H[LRScheduler]
A --> I[Gang System]
A --> P[Metrics Logging]
A --> V[Validation]
%% TrainUnit Components
B --> F[Model]
B --> G[MetricBag]
%% Gang System
I --> J[Root Gang]
I --> K[DP Gang]
I --> L[TP Gang]
%% Metrics Logging
P --> P1[TensorBoard]
P --> P2[WandB]
P --> P3[JSON Logger]
%% Validation
V --> Q[EvalUnit]
V --> R[Validation DataReader]
%% CheckpointManager Details
E --> E1[Save State]
E --> E2[Load State]
E --> E3[Keep Best Checkpoints]
E --> E4[Save FSDP Model]
The TrainUnit is an abstract class that encapsulates model-specific training logic:
    class TrainUnit(ABC, Generic[BatchT_contra]):
        """Represents a unit to be used with Trainer."""

        @abstractmethod
        def __call__(self, batch: BatchT_contra) -> tuple[Tensor, int | None]:
            """Process batch and return loss and number of targets."""

        @abstractmethod
        def set_step_nr(self, step_nr: int) -> None:
            """Set current training step number."""

        @property
        @abstractmethod
        def model(self) -> Module:
            """The underlying model."""

        @property
        @abstractmethod
        def metric_bag(self) -> MetricBag:
            """Training-related metrics."""
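To make the interface concrete, here is a minimal sketch of a custom train unit. The LabeledBatch type, the classification model, and the cross-entropy loss are illustrative assumptions rather than part of the fairseq2 API shown above; the MetricBag is assumed to be constructed elsewhere and passed in.

    from dataclasses import dataclass

    import torch.nn.functional as F
    from torch import Tensor
    from torch.nn import Module


    @dataclass
    class LabeledBatch:
        """A hypothetical batch of input features and integer class labels."""

        features: Tensor
        labels: Tensor


    class ClassificationTrainUnit(TrainUnit[LabeledBatch]):
        """Computes a cross-entropy loss for a classification model."""

        def __init__(self, model: Module, metric_bag: MetricBag) -> None:
            self._model = model
            self._metric_bag = metric_bag
            self._step_nr = 0

        def __call__(self, batch: LabeledBatch) -> tuple[Tensor, int | None]:
            logits = self._model(batch.features)

            # Return the summed loss along with the number of targets it covers.
            loss = F.cross_entropy(logits, batch.labels, reduction="sum")

            return loss, batch.labels.numel()

        def set_step_nr(self, step_nr: int) -> None:
            self._step_nr = step_nr

        @property
        def model(self) -> Module:
            return self._model

        @property
        def metric_bag(self) -> MetricBag:
            return self._metric_bag

With a train unit in place, the trainer itself is configured as follows: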
    # Example Trainer Configuration
    trainer = Trainer(
        # Basic parameters
        unit=train_unit,                     # Training unit to compute loss
        data_reader=data_reader,             # Data reader for training batches
        optimizer=optimizer,                 # Optimizer
        checkpoint_manager=checkpoint_mgr,   # Checkpoint manager
        root_gang=root_gang,                 # Root gang for distributed training
        # Optional parameters
        dp_gang=dp_gang,                     # Data parallel gang
        tp_gang=tp_gang,                     # Tensor parallel gang
        dtype=torch.float32,                 # Model data type
        lr_scheduler=lr_scheduler,           # Learning rate scheduler
        max_num_steps=100_000,               # Maximum training steps
        max_num_data_epochs=10,              # Maximum training epochs
        # Validation parameters
        valid_units=[valid_unit],            # Validation units
        valid_data_readers=[valid_reader],   # Validation data readers
        validate_every_n_steps=1_000,        # Validation frequency
        # Checkpoint parameters
        checkpoint_every_n_steps=5_000,      # Checkpoint frequency
        keep_last_n_checkpoints=5,           # Number of checkpoints to keep
        keep_best_n_checkpoints=3,           # Number of best checkpoints to keep
        keep_last_n_models=5,                # Number of models to keep
        keep_best_n_models=3,                # Number of best models to keep
        # Metric parameters
        publish_metrics_every_n_steps=100,   # Metric publishing frequency
        tb_dir=Path("runs"),                 # TensorBoard directory
        metrics_dir=Path("metrics"),         # Metrics directory
        # Advanced parameters
        fp16_loss_scale=(128.0, 0.0001),     # Initial and min loss scale for fp16
        max_gradient_norm=None,              # Max gradient norm for clipping
        amp=False,                           # Enable automatic mixed precision
        anomaly_detection=False,             # Enable autograd anomaly detection
        seed=2,                              # Random seed
    )
The training process follows this simplified sequence:
sequenceDiagram
participant T as Trainer
participant U as TrainUnit
participant D as DataReader
participant M as Model
participant O as Optimizer
T->>D: Request batch
D-->>T: Return batch
T->>U: Process batch
U->>M: Forward pass
M-->>U: Return loss
U-->>T: Return loss, num_targets
T->>M: Backward pass
T->>O: Update parameters
T->>T: Update metrics
Step-by-step breakdown
The following simplified code snippets walk through the trainer's flow step by step.
Initialization: The trainer is initialized with the necessary components and configurations.
    def __init__(
        self,
        unit: TrainUnit[BatchT],
        data_reader: DataReader[BatchT],
        ...
    ):
        self._model = unit.model
        self._unit = unit
        self._data_reader = data_reader
        # ... initialize other components
Training Loop: The training loop is implemented in the _do_run method:
    def _do_run(self) -> None:
        while self._should_run_step():
            self._step_nr += 1

            # Run training step
            self._run_step()

            # Maybe validate
            if self._should_validate():
                self._validate()

            # Maybe checkpoint
            if self._should_checkpoint():
                self._checkpoint()
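The predicates in this loop reduce to step-count checks against the configured intervals. The following is a conceptual sketch only, not the actual implementation, which also accounts for epoch limits and end of data:

    def _should_run_step(self) -> bool:
        # Keep stepping until the configured maximum number of steps is reached.
        return self._max_num_steps is None or self._step_nr < self._max_num_steps

    def _should_validate(self) -> bool:
        # Validate every `validate_every_n_steps` steps.
        return self._step_nr % self._validate_every_n_steps == 0

    def _should_checkpoint(self) -> bool:
        # Checkpoint every `checkpoint_every_n_steps` steps.
        return self._step_nr % self._checkpoint_every_n_steps == 0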
Step Execution: The _run_step method executes a single training step. A step may cover more than one batch; looping over all of them before the single optimizer update is how gradient accumulation is realized:
    def _run_step(self) -> None:
        # Collect batches
        batches = self._next_batches()

        # Process each batch
        for batch in batches:
            # Forward pass
            loss, num_targets = self._unit(batch)

            # Backward pass
            self._loss_scaler.backward(loss)

        # Update parameters
        self._loss_scaler.run_optimizer_step(self._step_nr)
Validation: The validation loop is implemented in the _validate method, which runs each EvalUnit over its validation data reader with gradients disabled and then publishes the resulting metrics.
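A simplified sketch of that loop is shown below; it is illustrative only, and helper names such as _publish_metrics are assumptions rather than the actual private API.

    def _validate(self) -> None:
        self._model.eval()

        with torch.no_grad():
            for unit, data_reader in zip(self._valid_units, self._valid_data_readers):
                for batches in data_reader:
                    for batch in batches:
                        unit(batch)  # Updates the unit's validation metrics

                # Publish this unit's metrics for the current validation run.
                self._publish_metrics(unit)

        self._model.train()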
Best practices

Metric Tracking:
- Register all relevant metrics in the train unit (see the sketch after this list)
- Use appropriate metric types (Mean, Sum, etc.)
- Consider adding validation metrics
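For example, the sketch below shows how Mean and Sum metrics from torcheval can be updated from a training step; the values are placeholders, and how the metric objects are registered on the unit's MetricBag varies across fairseq2 versions, so registration is not shown.

    import torch
    from torcheval.metrics import Mean, Sum

    # Mean suits per-target quantities such as loss; Sum suits counters.
    loss_metric = Mean()
    num_targets_metric = Sum()

    # Placeholder values standing in for the output of a train unit's __call__.
    loss = torch.tensor(42.0)
    num_targets = 128

    loss_metric.update(loss.detach() / num_targets)
    num_targets_metric.update(torch.tensor(float(num_targets)))

    # The metric objects would be registered on the unit's MetricBag so that the
    # Trainer can sync them across gangs and publish them.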
Resource Management:
- Use appropriate batch sizes for your hardware
- Enable amp for memory efficiency
- Configure gradient accumulation as needed
Checkpoint Management:
- Save checkpoints regularly
- Use both keep_last_n_checkpoints and keep_best_n_checkpoints
- Consider separate policies for full checkpoints vs models
Validation:
- Validate at appropriate intervals
- Track relevant validation metrics
- Implement early stopping if needed, for example with a custom early_stopper callback:
    def early_stopper(step_nr: int, score: float) -> bool:
        # Custom early stopping logic
        return score < threshold

    metric_descriptors = get_runtime_context().get_registry(MetricDescriptor)

    try:
        score_metric_descriptor = metric_descriptors.get(metric_name)
    except LookupError:
        raise UnknownMetricDescriptorError(metric_name) from None

    trainer = Trainer(
        early_stopper=early_stopper,
        score_metric_descriptor=score_metric_descriptor,
        lower_better=True,
    )
Custom Learning Rate Scheduling:
    class CustomLRScheduler(LRScheduler):
        def get_lr(self) -> float:
            # Custom LR calculation
            return self.base_lr * decay_factor(self.step_nr)

    trainer = Trainer(
        lr_scheduler=CustomLRScheduler(optimizer),
    )
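For common schedules, a standard PyTorch scheduler may also be an option. The sketch below uses torch.optim.lr_scheduler.LambdaLR (plain PyTorch, not a fairseq2-specific API) to implement linear warmup followed by inverse-square-root decay; whether it can be passed to the Trainer directly depends on the LRScheduler type the installed fairseq2 version expects.

    import torch
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import LambdaLR

    model = torch.nn.Linear(16, 16)  # Placeholder model for illustration
    optimizer = AdamW(model.parameters(), lr=1e-3)

    num_warmup_steps = 1_000

    def lr_lambda(step: int) -> float:
        # Linear warmup to the base LR, then decay proportional to 1/sqrt(step).
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        return (num_warmup_steps / step) ** 0.5

    lr_scheduler = LambdaLR(optimizer, lr_lambda)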