Trainer

The fairseq2.recipes.trainer.Trainer class is the core class for training models.

Overview

The trainer in fairseq2 is designed to be flexible and model-agnostic, handling scenarios that range from simple single-device runs to complex distributed training setups. It is one of the most complex systems in fairseq2, but also one of the most powerful.

    flowchart LR
    %% Main Trainer Class
    A[Trainer] --> B[TrainUnit]
    A --> C[DataReader]
    A --> D[Optimizer]
    A --> E[CheckpointManager]
    A --> H[LRScheduler]
    A --> I[Gang System]
    A --> P[Metrics Logging]
    A --> V[Validation]

    %% TrainUnit Components
    B --> F[Model]
    B --> G[MetricBag]

    %% Gang System
    I --> J[Root Gang]
    I --> K[DP Gang]
    I --> L[TP Gang]

    %% Metrics Logging
    P --> P1[TensorBoard]
    P --> P2[WandB]
    P --> P3[JSON Logger]

    %% Validation
    V --> Q[EvalUnit]
    V --> R[Validation DataReader]

    %% CheckpointManager Details
    E --> E1[Save State]
    E --> E2[Load State]
    E --> E3[Keep Best Checkpoints]
    E --> E4[Save FSDP Model]
    

Core Components

TrainUnit

The TrainUnit is an abstract class that encapsulates model-specific training logic:

from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from torch import Tensor
from torch.nn import Module

from fairseq2.metrics import MetricBag

BatchT_contra = TypeVar("BatchT_contra", contravariant=True)


class TrainUnit(ABC, Generic[BatchT_contra]):
    """Represents a unit to be used with Trainer."""

    @abstractmethod
    def __call__(self, batch: BatchT_contra) -> tuple[Tensor, int | None]:
        """Process batch and return loss and number of targets."""

    @abstractmethod
    def set_step_nr(self, step_nr: int) -> None:
        """Set current training step number."""

    @property
    @abstractmethod
    def model(self) -> Module:
        """The underlying model."""

    @property
    @abstractmethod
    def metric_bag(self) -> MetricBag:
        """Training-related metrics."""

Example implementation:

from torcheval.metrics import Mean

class TransformerTrainUnit(TrainUnit[TransformerBatch]):
    def __init__(self, model: TransformerModel) -> None:
        self._model = model
        self._metric_bag = MetricBag()
        self._metric_bag.register_metric("loss", Mean())

    def __call__(self, batch: TransformerBatch) -> tuple[Tensor, int]:
        outputs = self._model(**batch)
        return outputs.loss, batch.num_tokens

    # set_step_nr() and the model/metric_bag properties must also be
    # implemented to satisfy the TrainUnit interface (omitted for brevity).

Trainer Configuration

The fairseq2.recipes.trainer.Trainer class accepts a wide range of configuration options:

# Example Trainer Configuration
trainer = Trainer(
    # Basic parameters
    unit=train_unit,                     # Training unit to compute loss
    data_reader=data_reader,             # Data reader for training batches
    optimizer=optimizer,                 # Optimizer
    checkpoint_manager=checkpoint_mgr,   # Checkpoint manager
    root_gang=root_gang,                 # Root gang for distributed training

    # Optional parameters
    dp_gang=dp_gang,                     # Data parallel gang
    tp_gang=tp_gang,                     # Tensor parallel gang
    dtype=torch.float32,                 # Model data type
    lr_scheduler=lr_scheduler,           # Learning rate scheduler
    max_num_steps=100_000,               # Maximum training steps
    max_num_data_epochs=10,              # Maximum training epochs

    # Validation parameters
    valid_units=[valid_unit],            # Validation units
    valid_data_readers=[valid_reader],   # Validation data readers
    validate_every_n_steps=1_000,        # Validation frequency

    # Checkpoint parameters
    checkpoint_every_n_steps=5_000,      # Checkpoint frequency
    keep_last_n_checkpoints=5,           # Number of checkpoints to keep
    keep_best_n_checkpoints=3,           # Number of best checkpoints to keep
    keep_last_n_models=5,                # Number of models to keep
    keep_best_n_models=3,                # Number of best models to keep

    # Metric parameters
    publish_metrics_every_n_steps=100,   # Metric publishing frequency
    tb_dir=Path("runs"),                 # TensorBoard directory
    metrics_dir=Path("metrics"),         # Metrics directory

    # Advanced parameters
    fp16_loss_scale=(128.0, 0.0001),    # Initial and min loss scale for fp16
    max_gradient_norm=None,              # Max gradient norm for clipping
    amp=False,                           # Enable automatic mixed precision
    anomaly_detection=False,             # Enable autograd anomaly detection
    seed=2                               # Random seed
)
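
Once constructed, the trainer drives the whole loop itself; in the recipe system it is simply invoked as a callable (a minimal sketch, assuming the components above have already been built):

# Run training end to end. Validation, checkpointing, and metric publishing
# all happen inside this call, following the configuration above.
trainer()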

Training Flow

The training process follows this simplified sequence:

    sequenceDiagram
    participant T as Trainer
    participant U as TrainUnit
    participant D as DataReader
    participant M as Model
    participant O as Optimizer

    T->>D: Request batch
    D-->>T: Return batch
    T->>U: Process batch
    U->>M: Forward pass
    M-->>U: Return loss
    U-->>T: Return loss, num_targets
    T->>M: Backward pass
    T->>O: Update parameters
    T->>T: Update metrics
    
Step-by-step breakdown

The following simplified code snippets break the trainer down step by step to help you understand the training flow.

  1. Initialization: The trainer is initialized with the necessary components and configurations.

def __init__(self, unit: TrainUnit[BatchT], data_reader: DataReader[BatchT], ...):
    self._model = unit.model
    self._unit = unit
    self._data_reader = data_reader
    # ... initialize other components
  2. Training Loop: The training loop is implemented in the _do_run method:

def _do_run(self) -> None:
    while self._should_run_step():
        self._step_nr += 1

        # Run training step
        self._run_step()

        # Maybe validate
        if self._should_validate():
            self._validate()

        # Maybe checkpoint
        if self._should_checkpoint():
            self._checkpoint()
  3. Step Execution: The _run_step method is responsible for executing a single training step (a standalone gradient-accumulation sketch follows this list):

def _run_step(self) -> None:
    # Collect batches (more than one batch means gradient accumulation)
    batches = self._next_batches()

    # Process each batch
    for batch in batches:
        # Forward pass
        loss, num_targets = self._unit(batch)

        # Backward pass
        self._loss_scaler.backward(loss)

    # Update parameters once per step, after all gradients have accumulated
    self._loss_scaler.run_optimizer_step(self._step_nr)
  4. Validation: The validation loop is implemented in the _validate method (a sketch of scoring a single unit follows this list):

def _validate(self) -> None:
    self._model.eval()

    with summon_fsdp_for_validation(self._model):
        unit_scores = []

        for unit, data_reader in zip(self._valid_units, self._valid_data_readers):
            unit_score = self._validate_unit(unit, data_reader)
            if unit_score is not None:
                unit_scores.append(unit_score)

        self._valid_score = self._compute_valid_score(unit_scores)

    self._model.train()
  5. Checkpoint Management: The trainer supports flexible checkpoint management:
    • Save checkpoints at regular intervals (steps or epochs)

    • Keep N most recent checkpoints

    • Keep N best checkpoints based on validation score

    • Separate policies for full checkpoints vs model-only checkpoints

    • Support for FSDP model consolidation

  6. Metrics Logging: The trainer supports multiple logging backends:
    • TensorBoard: Visualize training curves

    • JSON Logs: Store metrics in files

    • Weights & Biases (WandB): Collaborative experiment tracking
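
Because _next_batches() in step 3 returns a list of batches, gradient accumulation falls out naturally: gradients from every batch in the list are accumulated before the single optimizer step at the end. The same pattern in plain PyTorch, as a self-contained illustrative sketch rather than fairseq2's actual implementation:

from torch.nn.utils import clip_grad_norm_

def accumulation_step(model, optimizer, batches, max_grad_norm=None):
    """Accumulate gradients over several micro-batches, then update once."""
    optimizer.zero_grad(set_to_none=True)

    for batch in batches:
        # Scale each loss so the accumulated gradient equals the average.
        loss = model(**batch).loss / len(batches)
        loss.backward()

    if max_grad_norm is not None:
        clip_grad_norm_(model.parameters(), max_grad_norm)

    optimizer.step()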
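
Similarly, for step 4, each validation unit is scored by iterating its data reader with gradients disabled. A hedged sketch of what scoring a single unit could look like; the _publish_validation_metrics helper and the reset() call are illustrative names, not fairseq2's actual internals:

def _validate_unit(self, unit: EvalUnit[BatchT], data_reader: DataReader[BatchT]) -> float | None:
    with torch.inference_mode():
        # Feed every validation batch through the unit; the unit updates its
        # own metric bag as a side effect.
        for batches in data_reader:
            for batch in batches:
                unit(batch)

    data_reader.reset()

    # Publish the unit's metrics and return its score, if it defines one.
    return self._publish_validation_metrics(unit)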

Best Practices

  1. Metric Tracking:
    • Register all relevant metrics in the train unit (see the sketch after this list)

    • Use appropriate metric types (Mean, Sum, etc.)

    • Consider adding validation metrics

  2. Resource Management:
    • Use appropriate batch sizes for your hardware

    • Enable amp for memory efficiency

    • Configure gradient accumulation as needed

  3. Checkpoint Management:
    • Save checkpoints regularly

    • Use both keep_last_n_checkpoints and keep_best_n_checkpoints

    • Consider separate policies for full checkpoints vs model-only checkpoints

  4. Validation:
    • Validate at appropriate intervals

    • Track relevant validation metrics

    • Implement early stopping if needed
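
For the metric-tracking practice above, registering metrics in a unit's MetricBag is a one-liner per metric. A brief sketch using the torcheval metric types and the MetricBag API shown earlier; the metric names are illustrative:

from torcheval.metrics import Mean, Sum

from fairseq2.metrics import MetricBag

metric_bag = MetricBag()

# Track the average loss per step and the total number of target tokens seen.
metric_bag.register_metric("loss", Mean())
metric_bag.register_metric("num_target_elements", Sum())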

Advanced Features

  1. Early Stopping:

    def early_stopper(step_nr: int, score: float) -> bool:
        # Custom early stopping logic: stop once the validation score falls
        # below a user-defined threshold.
        return score < threshold
    
    # Look up the descriptor of the metric used as the validation score
    # (metric_name is a user-chosen name, e.g. "loss").
    metric_descriptors = get_runtime_context().get_registry(MetricDescriptor)
    
    try:
        score_metric_descriptor = metric_descriptors.get(metric_name)
    except LookupError:
        raise UnknownMetricDescriptorError(metric_name) from None
    
    trainer = Trainer(
        early_stopper=early_stopper,
        score_metric_descriptor=score_metric_descriptor,
        lower_better=True,
        # ... plus the required arguments shown earlier (unit, data_reader, etc.)
    )
    
  2. Custom Learning Rate Scheduling:

    class CustomLRScheduler(LRScheduler):
        def get_lr(self) -> float:
            # Custom LR calculation; decay_factor is a user-defined function
            # of the current step number.
            return self.base_lr * decay_factor(self.step_nr)

    trainer = Trainer(
        lr_scheduler=CustomLRScheduler(optimizer),
        # ... plus the required arguments shown earlier
    )
    
  3. Profiling:

    num_skip_steps, num_record_steps = (100, 10)
    
    profile_dir = Path("logs/tb")
    
    profiler = TorchProfiler(
        num_skip_steps, num_record_steps, profile_dir, gangs.root
    )
    
    trainer = Trainer(
        profiler=profiler,
        ...
    )