fairseq2.recipes.trainer
(Class inheritance diagram omitted; `Trainer` derives from `StatefulObjectBag` and `Generic[BatchT]`.)
Classes
- final class fairseq2.recipes.trainer.Trainer(*, unit, data_reader, root_gang, optimizer, checkpoint_manager, wall_watch, dp_gang=None, tp_gang=None, dtype=torch.float32, lr_scheduler=None, fp16_loss_scale=(128.0, 0.0001), max_gradient_norm=None, amp=False, max_num_steps=None, max_num_data_epochs=None, score_metric_name=None, lower_better=False, early_stopper=None, valid_units=None, valid_data_readers=None, validate_after_n_steps=0, validate_every_n_steps=None, validate_after_n_data_epochs=0, validate_every_n_data_epochs=None, checkpoint_after_n_steps=0, checkpoint_every_n_steps=None, checkpoint_after_n_data_epochs=0, checkpoint_every_n_data_epochs=None, keep_last_n_checkpoints=None, keep_best_n_checkpoints=None, keep_last_n_models=None, keep_best_n_models=None, metric_recorders=None, tb_dir=None, metrics_dir=None, wandb_options=None, publish_metrics_after_n_steps=0, publish_metrics_every_n_steps=None, publish_metrics_after_n_data_epochs=0, publish_metrics_every_n_data_epochs=None, profile=None, anomaly_detection=False, seed=2)
Bases: StatefulObjectBag, Generic[BatchT]
Trains a machine learning model.
- Parameters:
unit (TrainUnit[BatchT]) – The training unit.
data_reader (DataReader[BatchT]) – The data reader for training.
root_gang (Gang) – The gang for distributed training.
optimizer (Optimizer) – The parameter optimizer.
checkpoint_manager (CheckpointManager) – The checkpoint manager.
wall_watch (Stopwatch) – The stopwatch to track process wall-time.
dtype (DataType) – The data type of the model.
dp_gang (Gang | None) – The data parallel gang. If None, root_gang will be used.
tp_gang (Gang | None) – The tensor parallel gang. Only required for tensor parallel models.
lr_scheduler (LRScheduler | None) – The learning rate scheduler.
amp (bool) – If True, enables torch.amp.
fp16_loss_scale (tuple[float, float]) – The initial and minimum loss scale for fp16 training.
max_gradient_norm (float | None) – The maximum gradient norm. If None, no clipping will be applied.
max_num_steps (int | None) – The maximum number of steps to train for.
max_num_data_epochs (int | None) – The maximum number of data epochs to train for.
score_metric_name (str | None) – The name of the metric to use for score calculation.
lower_better (bool) – If True, lower scores are considered better.
early_stopper (EarlyStopper | None) – The early-stopper callable.
valid_units (Sequence[EvalUnit[BatchT]] | None) – The evaluation units for validating the model.
valid_data_readers (Sequence[DataReader[BatchT]] | None) – The data readers corresponding to each unit in valid_units.
validate_after_n_steps (int) – The number of steps after which to start validating the model.
validate_every_n_steps (int | None) – The step interval at which to validate the model.
validate_after_n_data_epochs (int) – The number of data epochs after which to start validating the model.
validate_every_n_data_epochs (int | None) – The data epoch interval at which to validate the model.
checkpoint_after_n_steps (int) – The number of steps after which to start checkpointing.
checkpoint_every_n_steps (int | None) – The step interval at which to checkpoint.
checkpoint_after_n_data_epochs (int) – The number of data epochs after which to start checkpointing.
checkpoint_every_n_data_epochs (int | None) – The data epoch interval at which to checkpoint.
keep_last_n_checkpoints (int | None) – The number of checkpoints to keep. If None, none will be deleted.
keep_best_n_checkpoints (int | None) – The number of checkpoints to keep based on their validation score. If None, none will be deleted.
keep_last_n_models (int | None) – The number of checkpoint models to keep. Must be greater than or equal to keep_last_n_checkpoints.
keep_best_n_models (int | None) – The number of best checkpoint models to keep based on their validation score. Must be greater than or equal to keep_best_n_checkpoints.
metric_recorders (Iterable[MetricRecorder] | None) – The metric recorders.
tb_dir (Path | None) – Legacy. Use metric_recorders instead.
metrics_dir (Path | None) – Legacy. Use metric_recorders instead.
wandb_options (tuple[Path, str, str] | None) – Legacy. Use metric_recorders instead.
publish_metrics_after_n_steps (int) – The number of steps after which to start publishing metrics.
publish_metrics_every_n_steps (int | None) – The step interval at which to publish metrics.
publish_metrics_after_n_data_epochs (int) – The number of data epochs after which to start publishing metrics.
publish_metrics_every_n_data_epochs (int | None) – The data epoch interval at which to publish metrics.
profile (tuple[int, int] | None) – The number of steps that the PyTorch profiler should skip and then record.
anomaly_detection (bool) – If True, turns on the anomaly detection feature in torch.autograd.
seed (int) – The random number generator seed.