fairseq2.recipes.trainer

        classDiagram
  ABC <|-- AbstractAsyncContextManager
  ABC <|-- AbstractContextManager
  ABC <|-- AbstractLRScheduler
  ABC <|-- CheckpointManager
  ABC <|-- DataReader
  ABC <|-- EvalUnit
  ABC <|-- Gang
  ABC <|-- Joinable
  ABC <|-- Metric
  ABC <|-- MetricRecorder
  ABC <|-- StateHandler
  ABC <|-- TrainUnit
  AbstractAsyncContextManager <|-- nullcontext
  AbstractContextManager <|-- nullcontext
  AbstractGang <|-- FakeGang
  AbstractLRScheduler <|-- NoopLR
  BaseException <|-- Exception
  CheckpointError <|-- CheckpointNotFoundError
  Collection <|-- Sequence
  Container <|-- Collection
  ContextDecorator <|-- record_function
  Exception <|-- CheckpointError
  Exception <|-- ContractError
  Exception <|-- InternalError
  Exception <|-- InvalidOperationError
  Gang <|-- AbstractGang
  Generic <|-- EvalUnit
  Generic <|-- Metric
  Generic <|-- Protocol
  Generic <|-- StateHandler
  Generic <|-- TrainUnit
  Generic <|-- Trainer
  Iterable <|-- Collection
  Iterable <|-- Iterator
  Iterable <|-- Reversible
  Iterator <|-- DataReader
  Joinable <|-- DistributedDataParallel
  JupyterMixin <|-- Progress
  LRScheduler <|-- _LRScheduler
  Metric <|-- Mean
  MetricRecorder <|-- JsonFileMetricRecorder
  MetricRecorder <|-- LogMetricRecorder
  MetricRecorder <|-- TensorBoardRecorder
  MetricRecorder <|-- WandbRecorder
  Module <|-- DistributedDataParallel
  Module <|-- FullyShardedDataParallel
  Protocol <|-- EarlyStopper
  PurePath <|-- Path
  Reversible <|-- Sequence
  Sized <|-- Collection
  StateHandler <|-- FSDPOptimizerStateHandler
  StatefulObjectBag <|-- Trainer
  TensorBase <|-- Tensor
  TrainUnit <|-- AbstractTrainUnit
  _FSDPState <|-- FullyShardedDataParallel
  _LRScheduler <|-- AbstractLRScheduler
  _State <|-- _FSDPState
    

Classes

final class fairseq2.recipes.trainer.Trainer(*, unit, data_reader, root_gang, optimizer, checkpoint_manager, wall_watch, dp_gang=None, tp_gang=None, dtype=torch.float32, lr_scheduler=None, fp16_loss_scale=(128.0, 0.0001), max_gradient_norm=None, amp=False, max_num_steps=None, max_num_data_epochs=None, score_metric_name=None, lower_better=False, early_stopper=None, valid_units=None, valid_data_readers=None, validate_after_n_steps=0, validate_every_n_steps=None, validate_after_n_data_epochs=0, validate_every_n_data_epochs=None, checkpoint_after_n_steps=0, checkpoint_every_n_steps=None, checkpoint_after_n_data_epochs=0, checkpoint_every_n_data_epochs=None, keep_last_n_checkpoints=None, keep_best_n_checkpoints=None, keep_last_n_models=None, keep_best_n_models=None, metric_recorders=None, tb_dir=None, metrics_dir=None, wandb_options=None, publish_metrics_after_n_steps=0, publish_metrics_every_n_steps=None, publish_metrics_after_n_data_epochs=0, publish_metrics_every_n_data_epochs=None, profile=None, anomaly_detection=False, seed=2)

Bases: StatefulObjectBag, Generic[BatchT]

Trains a machine learning model.

Parameters:
  • unit (TrainUnit[BatchT]) – The training unit.

  • data_reader (DataReader[BatchT]) – The data reader for training.

  • root_gang (Gang) – The gang for distributed training.

  • optimizer (Optimizer) – The parameter optimizer.

  • checkpoint_manager (CheckpointManager) – The checkpoint manager.

  • wall_watch (Stopwatch) – The stopwatch to track process wall-time.

  • dtype (DataType) – The data type of the model.

  • dp_gang (Gang | None) – The data parallel gang. If None, root_gang will be used.

  • tp_gang (Gang | None) – The tensor parallel gang. Only required for tensor parallel models.

  • lr_scheduler (LRScheduler | None) – The learning rate scheduler.

  • amp (bool) – If True, enables torch.amp.

  • fp16_loss_scale (tuple[float, float]) – The initial and minimum loss scale for fp16 training.

  • max_gradient_norm (float | None) – The maximum gradient norm. If None, no clipping will be applied.

  • max_num_steps (int | None) – The maximum number of steps to train for.

  • max_num_data_epochs (int | None) – The maximum number of data epochs to train for.

  • score_metric_name (str | None) – The name of the metric to use for score calculation.

  • lower_better (bool) – If True, lower scores are considered better.

  • early_stopper (EarlyStopper | None) – The early-stopper callable.

  • valid_units (Sequence[EvalUnit[BatchT]] | None) – The evaluation units for validating the model.

  • valid_data_readers (Sequence[DataReader[BatchT]] | None) – The data readers corresponding to each unit in valid_units.

  • validate_after_n_steps (int) – The number of steps after which to start validating the model.

  • validate_every_n_steps (int | None) – The step interval at which to validate the model.

  • validate_after_n_data_epochs (int) – The number of data epochs after which to start validating the model.

  • validate_every_n_data_epochs (int | None) – The data epoch interval at which to validate the model.

  • checkpoint_after_n_steps (int) – The number of steps after which to start checkpointing.

  • checkpoint_every_n_steps (int | None) – The step interval at which to checkpoint.

  • checkpoint_after_n_data_epochs (int) – The number of data epochs after which to start checkpointing.

  • checkpoint_every_n_data_epochs (int | None) – The data epoch interval at which to checkpoint.

  • keep_last_n_checkpoints (int | None) – The number of most recent checkpoints to keep. If None, none will be deleted.

  • keep_best_n_checkpoints (int | None) – The number of checkpoints to keep based on their validation score. If None, none will be deleted.

  • keep_last_n_models (int | None) – The number of checkpoint models to keep. Must be greater than or equal to keep_last_n_checkpoints.

  • keep_best_n_models (int | None) – The number of best checkpoint models to keep based on their validation score. Must be greater than or equal to keep_best_n_checkpoints.

  • metric_recorders (Iterable[MetricRecorder] | None) – The metric recorders.

  • tb_dir (Path | None) – Legacy. Use metric_recorders instead.

  • metrics_dir (Path | None) – Legacy. Use metric_recorders instead.

  • wandb_options (tuple[Path, str, str] | None) – Legacy. Use metric_recorders instead.

  • publish_metrics_after_n_steps (int) – The number of steps after which to start publishing metrics.

  • publish_metrics_every_n_steps (int | None) – The step interval at which to publish metrics.

  • publish_metrics_after_n_data_epochs (int) – The number of data epochs after which to start publishing metrics.

  • publish_metrics_every_n_data_epochs (int | None) – The data epoch interval at which to publish metrics.

  • profile (tuple[int, int] | None) – A pair giving the number of steps that the PyTorch profiler should skip and then the number of steps it should record.

  • anomaly_detection (bool) – If True, turns on the anomaly detection feature in torch.autograd.

  • seed (int) – The random number generator seed.
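The paired *_after_n_steps and *_every_n_steps parameters above (for validation, checkpointing, and metric publishing) describe a warm-up threshold followed by a fixed interval. The sketch below shows one plausible reading of that rule. The helper name should_run is hypothetical and not part of fairseq2, and the exact comparison the Trainer uses (for example, whether the threshold is inclusive) is an assumption.

```python
from __future__ import annotations


def should_run(step_nr: int, after_n_steps: int, every_n_steps: int | None) -> bool:
    """Hypothetical helper: return True if a periodic action is due at step_nr.

    Assumes the action is disabled when no interval is given, and that the
    threshold is inclusive. Neither detail is confirmed by the reference above.
    """
    if every_n_steps is None:
        return False  # no interval configured, so the action never runs
    if step_nr < after_n_steps:
        return False  # still inside the warm-up window
    return step_nr % every_n_steps == 0


# Steps at which validation would be due under this reading, using
# validate_after_n_steps=250 and validate_every_n_steps=100.
due = [step for step in range(1, 501) if should_run(step, 250, 100)]
print(due)
```

Under these assumptions the threshold only suppresses early runs; it does not shift the interval, so the first run lands on the first multiple of the interval at or past the threshold.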

request_stop()

Request a graceful stop of the training.
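A graceful stop is commonly implemented as a flag that the training loop checks between steps, so the step in progress finishes before the loop exits. The following is a minimal sketch of that pattern only. ToyLoop and all of its members are hypothetical stand-ins, and whether the fairseq2 Trainer checks the flag at exactly this point is an assumption.

```python
class ToyLoop:
    """Hypothetical stand-in for a training loop; not the fairseq2 Trainer."""

    def __init__(self, max_num_steps: int) -> None:
        self.max_num_steps = max_num_steps
        self.step_nr = 0
        self._stop_requested = False

    def request_stop(self) -> None:
        # Only record the request; the loop decides when it is safe to exit.
        self._stop_requested = True

    def train_step(self) -> None:
        # Placeholder for real work. Here the stop is requested from inside
        # step 3 to simulate, for example, an early-stopping condition.
        if self.step_nr == 3:
            self.request_stop()

    def run(self) -> None:
        while self.step_nr < self.max_num_steps:
            if self._stop_requested:
                break  # exit at a step boundary, never mid-step
            self.step_nr += 1
            self.train_step()


loop = ToyLoop(max_num_steps=10)
loop.run()
print(loop.step_nr)
```

Because the flag is checked only at the top of the loop, the step that issued the request still runs to completion, which is what makes the stop graceful.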