fairseq2.recipes.trainer

classDiagram
  ABC <|-- AbstractAsyncContextManager
  ABC <|-- AbstractContextManager
  ABC <|-- CheckpointManager
  ABC <|-- DataReader
  ABC <|-- DeviceStatTracker
  ABC <|-- EarlyStopper
  ABC <|-- EvalUnit
  ABC <|-- GarbageCollector
  ABC <|-- Joinable
  ABC <|-- Metric
  ABC <|-- MetricRecorder
  ABC <|-- Profiler
  ABC <|-- ProgressReporter
  ABC <|-- ProgressTask
  ABC <|-- StateHandler
  ABC <|-- TrainUnit
  AbstractAsyncContextManager <|-- nullcontext
  AbstractContextManager <|-- nullcontext
  BaseException <|-- Exception
  Collection <|-- Sequence
  Container <|-- Collection
  ContextDecorator <|-- record_function
  EarlyStopper <|-- NoopEarlyStopper
  Exception <|-- CheckpointNotFoundError
  Exception <|-- ContractError
  Exception <|-- InternalError
  Exception <|-- InvalidOperationError
  Generic <|-- EvalUnit
  Generic <|-- Metric
  Generic <|-- StateHandler
  Generic <|-- TrainUnit
  Generic <|-- Trainer
  Iterable <|-- Collection
  Iterable <|-- Iterator
  Iterable <|-- Reversible
  Iterator <|-- DataReader
  Joinable <|-- DistributedDataParallel
  LRScheduler <|-- _LRScheduler
  Metric <|-- Mean
  Module <|-- DistributedDataParallel
  Module <|-- FullyShardedDataParallel
  ProgressReporter <|-- NoopProgressReporter
  Reversible <|-- Sequence
  Sized <|-- Collection
  StateHandler <|-- FSDPOptimizerStateHandler
  StatefulObjectBag <|-- Trainer
  TensorBase <|-- Tensor
  TrainUnit <|-- AbstractTrainUnit
  _FSDPState <|-- FullyShardedDataParallel
  _State <|-- _FSDPState

Classes

final class fairseq2.recipes.trainer.Trainer(*, unit, data_reader, gangs, dtype, amp, optimizer, lr_scheduler, checkpoint_manager, metric_recorder, garbage_collector, profiler, device_stat_tracker, seed, wall_watch, fp16_loss_scale=(128.0, 0.0001), max_gradient_norm=None, max_num_steps=None, max_num_data_epochs=None, score_metric_descriptor=None, lower_better=False, early_stopper=None, valid_units=None, valid_data_readers=None, validate_after_n_steps=0, validate_every_n_steps=None, validate_after_n_data_epochs=0, validate_every_n_data_epochs=None, checkpoint_after_n_steps=0, checkpoint_every_n_steps=None, checkpoint_after_n_data_epochs=0, checkpoint_every_n_data_epochs=None, keep_last_n_checkpoints=None, keep_best_n_checkpoints=None, keep_last_n_models=None, keep_best_n_models=None, publish_metrics_after_n_steps=0, publish_metrics_every_n_steps=None, publish_metrics_after_n_data_epochs=0, publish_metrics_every_n_data_epochs=None, gradient_check=False, anomaly_detection=False)[source]

Bases: StatefulObjectBag, Generic[BatchT]

Trains a machine learning model.

Parameters:
  • unit (TrainUnit[BatchT]) – The training unit.

  • data_reader (DataReader[BatchT]) – The data reader for training.

  • gangs (Gangs) – The gangs to train on.

  • optimizer (Optimizer) – The parameter optimizer.

  • checkpoint_manager (CheckpointManager) – The checkpoint manager.

  • wall_watch (Stopwatch) – The stopwatch to track process wall-time.

  • dtype (DataType) – The data type of the model.

  • lr_scheduler (LRScheduler) – The learning rate scheduler.

  • amp (bool) – If True, enables torch.amp.

  • fp16_loss_scale (tuple[float, float]) – The initial and minimum loss scale for fp16 training.

  • max_gradient_norm (float | None) – The maximum gradient norm. If None, no clipping will be applied.

  • max_num_steps (int | None) – The maximum number of steps to train for.

  • max_num_data_epochs (int | None) – The maximum number of data epochs to train for.

  • score_metric_descriptor (MetricDescriptor | None) – The descriptor of the metric to use for score calculation.

  • lower_better (bool) – If True, lower scores are considered better.

  • early_stopper (EarlyStopper | None) – The early-stopper callable.

  • valid_units (Sequence[EvalUnit[BatchT]] | None) – The evaluation units for validating the model.

  • valid_data_readers (Sequence[DataReader[BatchT]] | None) – The data readers corresponding to each unit in valid_units.

  • validate_after_n_steps (int) – The number of steps after which to start validating the model.

  • validate_every_n_steps (int | None) – The step interval at which to validate the model.

  • validate_after_n_data_epochs (int) – The number of data epochs after which to start validating the model.

  • validate_every_n_data_epochs (int | None) – The data epoch interval at which to validate the model.

  • checkpoint_after_n_steps (int) – The number of steps after which to start checkpointing.

  • checkpoint_every_n_steps (int | None) – The step interval at which to checkpoint.

  • checkpoint_after_n_data_epochs (int) – The number of data epochs after which to start checkpointing.

  • checkpoint_every_n_data_epochs (int | None) – The data epoch interval at which to checkpoint.

  • keep_last_n_checkpoints (int | None) – The number of checkpoints to keep. If None, none will be deleted.

  • keep_best_n_checkpoints (int | None) – The number of checkpoints to keep based on their validation score. If None, none will be deleted.

  • keep_last_n_models (int | None) – The number of checkpoint models to keep. Must be greater than or equal to keep_last_n_checkpoints.

  • keep_best_n_models (int | None) – The number of best checkpoint models to keep based on their validation score. Must be greater than or equal to keep_best_n_checkpoints.

  • metric_recorder (MetricRecorder) – The metric recorder.

  • publish_metrics_after_n_steps (int) – The number of steps after which to start publishing metrics.

  • publish_metrics_every_n_steps (int | None) – The step interval at which to publish metrics.

  • publish_metrics_after_n_data_epochs (int) – The number of data epochs after which to start publishing metrics.

  • publish_metrics_every_n_data_epochs (int | None) – The data epoch interval at which to publish metrics.

  • profiler (Profiler) – The profiler.

  • garbage_collector (GarbageCollector) – The garbage collector.

  • device_stat_tracker (DeviceStatTracker) – The device statistics tracker.

  • anomaly_detection (bool) – If True, turns on anomaly detection feature in torch.autograd.

  • seed (int) – The random number generator seed.

request_stop()[source]

Requests a graceful stop of training.
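A graceful stop of this kind is typically implemented as a flag checked at step boundaries, so the step in progress runs to completion before the loop exits. A hypothetical sketch of that pattern (CooperativeLoop is illustrative, not the fairseq2 implementation):

```python
import threading

class CooperativeLoop:
    """Hypothetical sketch of the cooperative-stop pattern a graceful
    request_stop() implies: the stop flag is only consulted between
    steps, never mid-step."""

    def __init__(self, max_steps: int) -> None:
        self._stop_requested = threading.Event()
        self._max_steps = max_steps
        self.steps_completed = 0

    def request_stop(self) -> None:
        # Safe to call from another thread or a signal handler.
        self._stop_requested.set()

    def run(self) -> None:
        for _ in range(self._max_steps):
            if self._stop_requested.is_set():
                break  # exit at a step boundary, never mid-step
            self.steps_completed += 1
```

Using a threading.Event rather than a bare bool keeps the flag safe to set from another thread, e.g. a SIGTERM handler registered by the launcher.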