Trainer

The fairseq2.recipes.trainer.Trainer class is the core class for training models.

Overview

The trainer in fairseq2 is designed to be flexible and model-agnostic, handling various training scenarios from simple models to complex distributed training setups. It is probably the most complex system in fairseq2, but also the most powerful.

    flowchart LR
    %% Main Trainer Class
    A[Trainer] --> B[TrainUnit]
    A --> C[DataReader]
    A --> D[Optimizer]
    A --> E[CheckpointManager]
    A --> H[LRScheduler]
    A --> I[Gang System]
    A --> P[Metrics Logging]
    A --> V[Validation]

    %% TrainUnit Components
    B --> F[Model]
    B --> G[MetricBag]

    %% Gang System
    I --> J[Root Gang]
    I --> K[DP Gang]
    I --> L[TP Gang]

    %% Metrics Logging
    P --> P1[TensorBoard]
    P --> P2[WandB]
    P --> P3[JSON Logger]

    %% Validation
    V --> Q[EvalUnit]
    V --> R[Validation DataReader]

    %% CheckpointManager Details
    E --> E1[Save State]
    E --> E2[Load State]
    E --> E3[Keep Best Checkpoints]
    E --> E4[Save FSDP Model]
    

Core Components

TrainUnit

The TrainUnit is an abstract class that encapsulates model-specific training logic:

class TrainUnit(ABC, Generic[BatchT_contra]):
    """Represents a unit to be used with Trainer."""

    @abstractmethod
    def __call__(self, batch: BatchT_contra) -> tuple[Tensor, int | None]:
        """Process batch and return loss and number of targets."""

    @abstractmethod
    def set_step_nr(self, step_nr: int) -> None:
        """Set current training step number."""

    @property
    @abstractmethod
    def model(self) -> Module:
        """The underlying model."""

    @property
    @abstractmethod
    def metric_bag(self) -> MetricBag:
        """Training-related metrics."""

Example implementation:

class TransformerTrainUnit(TrainUnit[TransformerBatch]):
    """Simplified example; set_step_nr and the model/metric_bag properties are omitted for brevity."""

    def __init__(self, model: TransformerModel) -> None:
        self._model = model
        self._metric_bag = MetricBag()
        self._metric_bag.register_metric("loss", Mean())

    def __call__(self, batch: TransformerBatch) -> tuple[Tensor, int]:
        outputs = self._model(**batch)
        return outputs.loss, batch.num_tokens
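
To see how the trainer consumes such a unit, here is a rough, hypothetical manual step (a sketch only; the real training loop additionally handles gradient accumulation, loss scaling, gangs, and metric syncing):

# Hypothetical manual step, mirroring what the Trainer does internally.
unit = TransformerTrainUnit(model)

batches = next(iter(data_reader))  # a DataReader yields a list of batches per step
for batch in batches:
    # Forward pass via TrainUnit.__call__; num_targets lets the trainer normalize metrics.
    loss, num_targets = unit(batch)
    # Backward pass; the real Trainer routes this through its loss scaler / AMP machinery.
    loss.backward()

optimizer.step()
optimizer.zero_grad()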

Trainer Configuration

The fairseq2.recipes.trainer.Trainer class accepts a wide range of configuration options:

# Example Trainer Configuration
trainer = Trainer(
    # Basic parameters
    unit=train_unit,                     # Training unit to compute loss
    data_reader=data_reader,             # Data reader for training batches
    optimizer=optimizer,                 # Optimizer
    checkpoint_manager=checkpoint_mgr,   # Checkpoint manager
    root_gang=root_gang,                 # Root gang for distributed training

    # Optional parameters
    dp_gang=dp_gang,                     # Data parallel gang
    tp_gang=tp_gang,                     # Tensor parallel gang
    dtype=torch.float32,                 # Model data type
    lr_scheduler=lr_scheduler,           # Learning rate scheduler
    max_num_steps=100_000,               # Maximum training steps
    max_num_data_epochs=10,              # Maximum training epochs

    # Validation parameters
    valid_units=[valid_unit],            # Validation units
    valid_data_readers=[valid_reader],   # Validation data readers
    validate_every_n_steps=1_000,        # Validation frequency

    # Checkpoint parameters
    checkpoint_every_n_steps=5_000,      # Checkpoint frequency
    keep_last_n_checkpoints=5,           # Number of checkpoints to keep

    # Metric parameters
    publish_metrics_every_n_steps=100,   # Metric publishing frequency
    tb_dir=Path("runs"),                 # TensorBoard directory
    metrics_dir=Path("metrics"),         # Metrics directory
)

Training Flow

The training process follows this simplified sequence:

    sequenceDiagram
    participant T as Trainer
    participant U as TrainUnit
    participant D as DataReader
    participant M as Model
    participant O as Optimizer

    T->>D: Request batch
    D-->>T: Return batch
    T->>U: Process batch
    U->>M: Forward pass
    M-->>U: Return loss
    U-->>T: Return loss, num_targets
    T->>M: Backward pass
    T->>O: Update parameters
    T->>T: Update metrics
    
Step-by-step breakdown

The following simplified code snippets walk through the trainer step by step to help you understand the training flow.

  1. Initialization: The trainer is initialized with the necessary components and configurations.

def __init__(self, unit: TrainUnit[BatchT], data_reader: DataReader[BatchT], ...):
    self._model = unit.model
    self._unit = unit
    self._data_reader = data_reader
    # ... initialize other components
  2. Training Loop: The training loop is implemented in the _do_run method:

def _do_run(self) -> None:
    while self._should_run_step():
        self._step_nr += 1

        # Run training step
        self._run_step()

        # Maybe validate
        if self._should_validate():
            self._validate()

        # Maybe checkpoint
        if self._should_checkpoint():
            self._checkpoint()
  3. Step Execution: The _run_step method is responsible for executing a single training step:

def _run_step(self) -> None:
    # Collect batches
    batches = self._next_batches()

    # Process each batch
    for batch in batches:
        # Forward pass
        loss, num_targets = self._unit(batch)

        # Backward pass
        self._loss_scaler.backward(loss)

        # Update parameters
        self._loss_scaler.run_optimizer_step(self._step_nr)
  4. Validation: The validation loop is implemented in the _validate method:

def _validate(self) -> None:
    log.info("Starting validation after step {}.", self._step_nr)

    self._model.eval()

    with summon_fsdp_for_validation(self._model):
        unit_scores = []

        for unit, data_reader in zip(self._valid_units, self._valid_data_readers):
            unit_score = self._validate_unit(unit, data_reader)
            if unit_score is not None:
                unit_scores.append(unit_score)

        self._valid_score = self._compute_valid_score(unit_scores)

    self._model.train()

    log.info("Validation complete.")
  • Validation occurs at specified intervals (steps or epochs), e.g. validate_every_n_steps or validate_every_n_data_epochs.

  • It computes a score (such as loss or accuracy) using fairseq2.recipes.evaluator.EvalUnit objects and logs metrics.

  • The validation score is compared to previous scores to:
    • Save the best checkpoints.

    • Stop early if performance stagnates.
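
A minimal sketch of how these validation options fit together, reusing parameter names that appear elsewhere on this page (the omitted arguments are the core components shown earlier):

trainer = Trainer(
    # ... core components (unit, data_reader, optimizer, ...) as shown earlier ...
    valid_units=[valid_unit],            # EvalUnit objects that compute the validation score
    valid_data_readers=[valid_reader],   # one data reader per validation unit
    validate_every_n_steps=1_000,        # run validation every 1,000 steps
    score_metric_name="loss",            # metric used as the validation score
    lower_better=True,                   # a lower score counts as better
)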

  5. Checkpoint: The checkpointing logic is implemented in the _checkpoint method:

def _checkpoint(self) -> None:
    # Save checkpoint
    step_nr = self._step_nr

    self._checkpoint_manager.begin_checkpoint(step_nr)

    self._checkpoint_manager.save_state(
        self.state_dict(), model_key="_model"
    )
  • The trainer saves checkpoints periodically at specified intervals (steps or epochs), e.g. checkpoint_every_n_steps or checkpoint_every_n_data_epochs:
    • Both model weights and optimizer state are saved.

    • Best-performing models are saved based on the validation score, e.g. keep_best_n_checkpoints=3.

  • The checkpoint manager handles checkpoint saving and loading, which ensures:
    • Training can be resumed after interruptions.

    • Best models are preserved for deployment.
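
A minimal sketch combining the checkpoint-related options above (names reused from this page; exact defaults may differ in your fairseq2 version):

trainer = Trainer(
    # ... core components as shown earlier ...
    checkpoint_manager=checkpoint_mgr,   # saves and restores the full training state
    checkpoint_every_n_steps=5_000,      # periodic checkpointing interval
    keep_last_n_checkpoints=5,           # keep only the 5 most recent checkpoints
    keep_best_n_checkpoints=3,           # also keep the 3 best checkpoints by validation score
)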

  6. Metrics Logging: The metrics logging logic is implemented in the _publish_metrics method:

def _publish_metrics(self) -> None:
    if self._tp_gang.rank == 0:
        values = self._metric_bag.sync_and_compute_metrics()
        record_metrics(self._metric_recorders, "train", values, self._step_nr)
  • The trainer supports multiple logging backends:
    • TensorBoard: Visualize training curves
      • tb_dir = Path("logs/tb")

    • JSON Logs: Store metrics in files
      • metrics_dir = Path("logs/metrics")

    • Weights & Biases (WandB): Collaborative logging
      • wandb_options = (Path("logs/wandb"), "project_name", "run_name")
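
A minimal sketch enabling all three backends at once, using the option names from the bullets above (the wandb_options tuple layout follows the example given there):

trainer = Trainer(
    # ... core components as shown earlier ...
    publish_metrics_every_n_steps=100,   # how often metric values are published
    tb_dir=Path("logs/tb"),              # TensorBoard event files
    metrics_dir=Path("logs/metrics"),    # JSON metric logs
    wandb_options=(Path("logs/wandb"), "project_name", "run_name"),  # Weights & Biases
)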

Best Practices

  1. Metric Tracking:

    • Register all relevant metrics in the train unit

    • Use appropriate metric types (Mean, Sum, etc.)

    • Consider adding validation metrics

  2. Resource Management:

    • Use appropriate batch sizes for your hardware

    • Enable mixed precision (amp) for memory efficiency

    • Configure gradient accumulation as needed (see the sketch after this list)

  3. Checkpoint Management:

    • Save checkpoints regularly

    • Implement proper cleanup strategy

  4. Validation:

    • Validate at appropriate intervals

    • Track relevant validation metrics

    • Implement early stopping if needed
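
For resource management (item 2), the relevant knobs look roughly like the sketch below; dtype and amp are mentioned elsewhere on this page, while the gradient_accumulation parameter name is an assumption and should be checked against your fairseq2 version:

trainer = Trainer(
    # ... core components as shown earlier ...
    dtype=torch.bfloat16,        # lower-precision model dtype to reduce memory
    amp=True,                    # mixed precision (flag name assumed from the bullet above)
    gradient_accumulation=4,     # assumed parameter name: accumulate gradients over 4 batches
)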

Advanced Features

  1. Early Stopping:

    threshold = 0.05  # example target validation loss

    def early_stopper(step_nr: int, score: float) -> bool:
        # Custom early stopping logic: stop once the validation loss drops below the target.
        return score < threshold
    
    trainer = Trainer(
        early_stopper=early_stopper,
        score_metric_name="validation_loss",
        lower_better=True,
    )
    
  2. Custom Learning Rate Scheduling:

    class CustomLRScheduler(LRScheduler):
        def get_lr(self) -> float:
            # Custom LR calculation (schematic): scale the base LR by a user-defined decay_factor().
            return self.base_lr * decay_factor(self.step_nr)
    
    trainer = Trainer(
        lr_scheduler=CustomLRScheduler(optimizer),
    )
    
  3. Profiling:

    trainer = Trainer(
        profile=(100, 10),  # Skip 100 steps, profile 10 steps
        tb_dir=Path("logs/tb"),  # Save profiles to TensorBoard
    )