fairseq2.recipes.trainer
[Class inheritance diagram for fairseq2.recipes.trainer. Among the relationships shown: Trainer derives from StatefulObjectBag and Generic; AbstractTrainUnit derives from TrainUnit; NoopEarlyStopper derives from EarlyStopper.]
Classes
- final class fairseq2.recipes.trainer.Trainer(*, unit, data_reader, gangs, dtype, amp, optimizer, lr_scheduler, checkpoint_manager, metric_recorder, garbage_collector, profiler, device_stat_tracker, seed, wall_watch, fp16_loss_scale=(128.0, 0.0001), max_gradient_norm=None, max_num_steps=None, max_num_data_epochs=None, score_metric_descriptor=None, lower_better=False, early_stopper=None, valid_units=None, valid_data_readers=None, validate_after_n_steps=0, validate_every_n_steps=None, validate_after_n_data_epochs=0, validate_every_n_data_epochs=None, checkpoint_after_n_steps=0, checkpoint_every_n_steps=None, checkpoint_after_n_data_epochs=0, checkpoint_every_n_data_epochs=None, keep_last_n_checkpoints=None, keep_best_n_checkpoints=None, keep_last_n_models=None, keep_best_n_models=None, publish_metrics_after_n_steps=0, publish_metrics_every_n_steps=None, publish_metrics_after_n_data_epochs=0, publish_metrics_every_n_data_epochs=None, gradient_check=False, anomaly_detection=False)
Bases: StatefulObjectBag, Generic[BatchT]
Trains a machine learning model.
- Parameters:
unit (TrainUnit[BatchT]) – The training unit.
data_reader (DataReader[BatchT]) – The data reader for training.
gangs (Gangs) – The gangs to train on.
optimizer (Optimizer) – The parameter optimizer.
checkpoint_manager (CheckpointManager) – The checkpoint manager.
wall_watch (Stopwatch) – The stopwatch to track process wall-time.
dtype (DataType) – The data type of the model.
lr_scheduler (LRScheduler) – The learning rate scheduler.
amp (bool) – If True, enables torch.amp.
fp16_loss_scale (tuple[float, float]) – The initial and minimum loss scale for fp16 training.
max_gradient_norm (float | None) – The maximum gradient norm. If None, no clipping will be applied.
max_num_steps (int | None) – The maximum number of steps to train for.
max_num_data_epochs (int | None) – The maximum number of data epochs to train for.
score_metric_descriptor (MetricDescriptor | None) – The descriptor of the metric to use for score calculation.
lower_better (bool) – If True, lower scores are considered better.
early_stopper (EarlyStopper | None) – The early-stopper callable.
valid_units (Sequence[EvalUnit[BatchT]] | None) – The evaluation units for validating the model.
valid_data_readers (Sequence[DataReader[BatchT]] | None) – The data readers corresponding to each unit in valid_units.
validate_after_n_steps (int) – The number of steps after which to start validating the model.
validate_every_n_steps (int | None) – The step interval at which to validate the model.
validate_after_n_data_epochs (int) – The number of data epochs after which to start validating the model.
validate_every_n_data_epochs (int | None) – The data epoch interval at which to validate the model.
checkpoint_after_n_steps (int) – The number of steps after which to start checkpointing.
checkpoint_every_n_steps (int | None) – The step interval at which to checkpoint.
checkpoint_after_n_data_epochs (int) – The number of data epochs after which to start checkpointing.
checkpoint_every_n_data_epochs (int | None) – The data epoch interval at which to checkpoint.
keep_last_n_checkpoints (int | None) – The number of checkpoints to keep. If None, none will be deleted.
keep_best_n_checkpoints (int | None) – The number of checkpoints to keep based on their validation score. If None, none will be deleted.
keep_last_n_models (int | None) – The number of checkpoint models to keep. Must be greater than or equal to keep_last_n_checkpoints.
keep_best_n_models (int | None) – The number of best checkpoint models to keep based on their validation score. Must be greater than or equal to keep_best_n_checkpoints.
metric_recorder (MetricRecorder) – The metric recorder.
publish_metrics_after_n_steps (int) – The number of steps after which to start publishing metrics.
publish_metrics_every_n_steps (int | None) – The step interval at which to publish metrics.
publish_metrics_after_n_data_epochs (int) – The number of data epochs after which to start publishing metrics.
publish_metrics_every_n_data_epochs (int | None) – The data epoch interval at which to publish metrics.
profiler – The profiler.
anomaly_detection (bool) – If True, turns on the anomaly detection feature in torch.autograd.
seed (int) – The random number generator seed.
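The validate, checkpoint, and publish_metrics parameters above all come in "after n" / "every n" pairs. As a rough illustration only, the sketch below shows one plausible reading of such a pair: the action fires once the step count has reached the "after" threshold and the step number is a multiple of the interval. The helper should_run is hypothetical and not part of fairseq2, and the inclusive threshold comparison is an assumption; the Trainer's actual boundary behavior may differ.

```python
from __future__ import annotations


def should_run(step_nr: int, after_n_steps: int, every_n_steps: int | None) -> bool:
    """Hypothetical helper: decide whether a periodic action fires at step_nr."""
    # An interval of None disables the periodic action entirely.
    if every_n_steps is None:
        return False

    # Assumption: the "after" threshold is inclusive.
    if step_nr < after_n_steps:
        return False

    # Fire only on steps that are a multiple of the interval.
    return step_nr % every_n_steps == 0


# Example values standing in for validate_after_n_steps=1000 and
# validate_every_n_steps=500 over a 2500-step run.
fired = [s for s in range(1, 2501) if should_run(s, 1000, 500)]
print(fired)  # [1000, 1500, 2000, 2500]
```

Under this reading, leaving an every_n parameter at its default of None means the corresponding action is never triggered on a step basis, whatever the after_n value is.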