Source code for fairseq2.recipes.trainer

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from contextlib import AbstractContextManager, nullcontext
from itertools import count
from pathlib import Path
from statistics import mean
from typing import Generic, TypeVar, final

import torch
import torch.distributed
from rich.progress import Progress, TaskID
from torch import Tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.profiler import record_function
from torcheval.metrics import Mean
from typing_extensions import override

from fairseq2.checkpoint import CheckpointManager, CheckpointNotFoundError
from fairseq2.datasets import DataReader
from fairseq2.error import ContractError, InternalError, InvalidOperationError
from fairseq2.gang import FakeGang, Gang, broadcast_flag
from fairseq2.logging import log
from fairseq2.metrics import (
    JsonFileMetricRecorder,
    LogMetricRecorder,
    MetricBag,
    MetricRecorder,
    TensorBoardRecorder,
    WandbRecorder,
    format_metric_value,
    record_metrics,
)
from fairseq2.nn.fsdp import summon_fsdp_for_validation
from fairseq2.nn.utils.gradient import (
    check_gradient_norms,
    clip_gradient_norm,
    normalize_gradients,
)
from fairseq2.optim import DynamicLossScaler
from fairseq2.optim.lr_scheduler import LRScheduler, NoopLR, get_effective_lr
from fairseq2.recipes.common_metrics import extend_batch_metrics
from fairseq2.recipes.early_stopper import EarlyStopper, NoopEarlyStopper
from fairseq2.recipes.evaluator import EvalUnit
from fairseq2.recipes.utils.rich import create_rich_progress
from fairseq2.typing import CPU, DataType
from fairseq2.utils.profiler import Profiler, Stopwatch
from fairseq2.utils.rng import RngBag
from fairseq2.utils.state import FSDPOptimizerStateHandler, StatefulObjectBag

BatchT = TypeVar("BatchT")

BatchT_contra = TypeVar("BatchT_contra", contravariant=True)


class TrainUnit(ABC, Generic[BatchT_contra]):
    """Represents a unit to be used with :class:`Trainer`."""

    @abstractmethod
    def __call__(self, batch: BatchT_contra) -> tuple[Tensor, int | None]:
        """Process ``batch``.

        :returns:
            - The loss.
            - The number of targets used to compute the loss. If ``None``, the
              model gradients won't be normalized.
        """

    @abstractmethod
    def set_step_nr(self, step_nr: int) -> None:
        """Set the current training step number."""

    @property
    @abstractmethod
    def model(self) -> Module:
        """The underlying model."""

    @property
    @abstractmethod
    def metric_bag(self) -> MetricBag:
        """The training-related metrics."""


class AbstractTrainUnit(TrainUnit[BatchT]):
    """Provides a skeletal implementation of :class:`TrainUnit`."""

    def __init__(self, model: Module) -> None:
        self._model = model

    @override
    def set_step_nr(self, step_nr: int) -> None:
        pass

    @final
    @property
    @override
    def model(self) -> Module:
        return self._model



@final
class Trainer(StatefulObjectBag, Generic[BatchT]):
    """Trains a machine learning model."""

    _model: Module
    _unit: TrainUnit[BatchT]
    _data_reader: DataReader[BatchT]
    _root_gang: Gang
    _dp_gang: Gang
    _tp_gang: Gang
    _dtype: DataType
    _optimizer: Optimizer
    _lr_scheduler: LRScheduler
    _loss_scaler: DynamicLossScaler
    _max_gradient_norm: float | None
    _amp: bool
    _step_nr: int
    _max_num_steps: int | None
    _data_epoch_nr: int
    _max_num_data_epochs: int | None
    _repeat_step: bool
    _read_data: bool
    _num_effective_batches: int
    _end_of_data_epoch: bool
    _end_of_data: bool
    _should_stop: bool
    _score_metric_name: str | None
    _lower_better: bool
    _early_stopper: EarlyStopper | None
    _best_step_and_score: tuple[int, float] | None
    _valid_score: float | None
    _valid_units: Sequence[EvalUnit[BatchT]]
    _valid_data_readers: Sequence[DataReader[BatchT]]
    _validate_after_n_steps: int
    _validate_every_n_steps: int | None
    _validate_after_n_data_epochs: int
    _validate_every_n_data_epochs: int | None
    _checkpoint_manager: CheckpointManager
    _checkpoint_after_n_steps: int
    _checkpoint_every_n_steps: int | None
    _checkpoint_after_n_data_epochs: int
    _checkpoint_every_n_data_epochs: int | None
    _keep_last_n_checkpoints: int | None
    _keep_best_n_checkpoints: int | None
    _keep_last_n_models: int | None
    _keep_best_n_models: int | None
    _metric_bag: MetricBag
    _metric_recorders: list[MetricRecorder]
    _publish_metrics_after_n_steps: int
    _publish_metrics_every_n_steps: int | None
    _publish_metrics_after_n_data_epochs: int
    _publish_metrics_every_n_data_epochs: int | None
    _profiler: Profiler
    _anomaly_detection: bool
    _seed: int
    _rng_bag: RngBag
    _wall_watch: Stopwatch
    _total_step_time: float
    _run: bool
    _progress: Progress
    _train_task_id: TaskID

    def __init__(
        self,
        *,
        unit: TrainUnit[BatchT],
        data_reader: DataReader[BatchT],
        root_gang: Gang,
        optimizer: Optimizer,
        checkpoint_manager: CheckpointManager,
        wall_watch: Stopwatch,
        dp_gang: Gang | None = None,
        tp_gang: Gang | None = None,
        dtype: DataType = torch.float32,
        lr_scheduler: LRScheduler | None = None,
        fp16_loss_scale: tuple[float, float] = (128.0, 0.0001),
        max_gradient_norm: float | None = None,
        amp: bool = False,
        max_num_steps: int | None = None,
        max_num_data_epochs: int | None = None,
        score_metric_name: str | None = None,
        lower_better: bool = False,
        early_stopper: EarlyStopper | None = None,
        valid_units: Sequence[EvalUnit[BatchT]] | None = None,
        valid_data_readers: Sequence[DataReader[BatchT]] | None = None,
        validate_after_n_steps: int = 0,
        validate_every_n_steps: int | None = None,
        validate_after_n_data_epochs: int = 0,
        validate_every_n_data_epochs: int | None = None,
        checkpoint_after_n_steps: int = 0,
        checkpoint_every_n_steps: int | None = None,
        checkpoint_after_n_data_epochs: int = 0,
        checkpoint_every_n_data_epochs: int | None = None,
        keep_last_n_checkpoints: int | None = None,
        keep_best_n_checkpoints: int | None = None,
        keep_last_n_models: int | None = None,
        keep_best_n_models: int | None = None,
        metric_recorders: Iterable[MetricRecorder] | None = None,
        tb_dir: Path | None = None,
        metrics_dir: Path | None = None,
        wandb_options: tuple[Path, str, str] | None = None,
        publish_metrics_after_n_steps: int = 0,
        publish_metrics_every_n_steps: int | None = None,
        publish_metrics_after_n_data_epochs: int = 0,
        publish_metrics_every_n_data_epochs: int | None = None,
        profile: tuple[int, int] | None = None,
        anomaly_detection: bool = False,
        seed: int = 2,
    ) -> None:
        """
        :param unit:
            The training unit.
        :param data_reader:
            The data reader for training.
        :param root_gang:
            The gang for distributed training.
        :param optimizer:
            The parameter optimizer.
        :param checkpoint_manager:
            The checkpoint manager.
        :param wall_watch:
            The stopwatch to track process wall-time.
        :param dtype:
            The data type of the model.
        :param dp_gang:
            The data parallel gang. If ``None``, ``root_gang`` will be used.
        :param tp_gang:
            The tensor parallel gang. Only required for tensor parallel models.
        :param lr_scheduler:
            The learning rate scheduler.
        :param amp:
            If ``True``, enables ``torch.amp``.
        :param fp16_loss_scale:
            The initial and minimum loss scale for fp16 training.
        :param max_gradient_norm:
            The maximum gradient norm. If ``None``, no clipping will be applied.
        :param max_num_steps:
            The maximum number of steps to train for.
        :param max_num_data_epochs:
            The maximum number of data epochs to train for.
        :param score_metric_name:
            The name of the metric to use for score calculation.
        :param lower_better:
            If ``True``, lower scores are considered better.
        :param early_stopper:
            The early-stopper callable.
        :param valid_units:
            The evaluation units for validating the model.
        :param valid_data_readers:
            The data readers corresponding to each unit in ``valid_units``.
        :param validate_after_n_steps:
            The number of steps after which to start validating the model.
        :param validate_every_n_steps:
            The step interval at which to validate the model.
        :param validate_after_n_data_epochs:
            The number of data epochs after which to start validating the model.
        :param validate_every_n_data_epochs:
            The data epoch interval at which to validate the model.
        :param checkpoint_after_n_steps:
            The number of steps after which to start checkpointing.
        :param checkpoint_every_n_steps:
            The step interval at which to checkpoint.
        :param checkpoint_after_n_data_epochs:
            The number of data epochs after which to start checkpointing.
        :param checkpoint_every_n_data_epochs:
            The data epoch interval at which to checkpoint.
        :param keep_last_n_checkpoints:
            The number of checkpoints to keep. If ``None``, none will be deleted.
        :param keep_best_n_checkpoints:
            The number of checkpoints to keep based on their validation score. If
            ``None``, none will be deleted.
        :param keep_last_n_models:
            The number of checkpoint models to keep. Must be greater than or
            equal to ``keep_last_n_checkpoints``.
        :param keep_best_n_models:
            The number of best checkpoint models to keep based on their
            validation score. Must be greater than or equal to
            ``keep_best_n_checkpoints``.
        :param metric_recorders:
            The metric recorders.
        :param tb_dir:
            Legacy. Use ``metric_recorders``.
        :param metrics_dir:
            Legacy. Use ``metric_recorders``.
        :param wandb_options:
            Legacy. Use ``metric_recorders``.
        :param publish_metrics_after_n_steps:
            The number of steps after which to start publishing metrics.
        :param publish_metrics_every_n_steps:
            The step interval at which to publish metrics.
        :param publish_metrics_after_n_data_epochs:
            The number of data epochs after which to start publishing metrics.
        :param publish_metrics_every_n_data_epochs:
            The data epoch interval at which to publish metrics.
        :param profile:
            The number of steps that the PyTorch profiler should skip and then
            record.
        :param anomaly_detection:
            If ``True``, turns on the anomaly detection feature in
            ``torch.autograd``.
        :param seed:
            The random number generator seed.
""" super().__init__() device = root_gang.device self._model = unit.model self._unit = unit self._data_reader = data_reader self._root_gang = root_gang if dp_gang is not None and tp_gang is not None: self._dp_gang = dp_gang self._tp_gang = tp_gang elif dp_gang is None and tp_gang is None: self._dp_gang = root_gang self._tp_gang = FakeGang(device=device) else: raise ValueError("`dp_gang` and `tp_gang` must be both specified.") if root_gang.rank == 0: if self._dp_gang.rank != 0 or self._tp_gang.rank != 0: raise ValueError( f"The coordinator process of `root_gang` (i.e. rank 0) must be rank 0 in `dp_gang` and `tp_gang`, but is {self._dp_gang.rank} and {self._tp_gang.rank} instead." ) self._dtype = dtype uses_fsdp = isinstance(self._model, FSDP) if uses_fsdp: self.register_stateful( "_optimizer", optimizer, FSDPOptimizerStateHandler(self._model) ) else: self._optimizer = optimizer self._lr_scheduler = lr_scheduler or NoopLR(optimizer) fp16_init_scale, fp16_min_scale = fp16_loss_scale self._loss_scaler = DynamicLossScaler( optimizer, root_gang, sharded=uses_fsdp or self._tp_gang.size > 1, init_scale=fp16_init_scale, min_scale=fp16_min_scale, gradient_accumulation=self._data_reader.num_accumulate, enabled=self._dtype == torch.float16, ) self._max_gradient_norm = max_gradient_norm self._amp = amp self.register_stateful("_step_nr", 0) if max_num_steps is not None: if max_num_steps <= 0: raise ValueError("`max_num_steps` must be greater than zero.") self._max_num_steps = max_num_steps self.register_stateful("_data_epoch_nr", 1) if max_num_data_epochs is not None: if max_num_data_epochs <= 0: raise ValueError("`max_num_data_epochs` must be greater than zero.") self._max_num_data_epochs = max_num_data_epochs self._repeat_step = False self._read_data = False # Indicates whether we have read any data. self._num_effective_batches = 0 self._end_of_data_epoch = False self._end_of_data = False self._should_stop = False self._score_metric_name = score_metric_name self._lower_better = lower_better if early_stopper is not None: if score_metric_name is None: raise ValueError( "`score_metric_name` must be specified when `early_stopper` is specified." ) if root_gang.rank != 0: early_stopper = NoopEarlyStopper() self._early_stopper = early_stopper else: self._early_stopper = None self.register_stateful("_best_step_and_score", None) self._valid_score = None if valid_units is None and valid_data_readers is None: self._valid_units = [] self._valid_data_readers = [] elif valid_units is not None and valid_data_readers is not None: if len(valid_units) != len(valid_data_readers): raise ValueError( f"The number of data readers in `valid_data_readers` must match the number of units in `valid_units` ({len(valid_units)}), but is {len(valid_data_readers)} instead." ) self._valid_units = valid_units self._valid_data_readers = valid_data_readers else: raise ValueError( "`valid_units` and `valid_data_readers` must be both specified." ) if validate_every_n_steps is not None: if validate_every_n_steps <= 0: raise ValueError("`validate_every_n_steps` must be greater than zero.") self._validate_after_n_steps = validate_after_n_steps self._validate_every_n_steps = validate_every_n_steps if validate_every_n_data_epochs is not None: if validate_every_n_data_epochs <= 0: raise ValueError( "`validate_every_n_data_epochs` must be greater than zero." 
                )

        self._validate_after_n_data_epochs = validate_after_n_data_epochs
        self._validate_every_n_data_epochs = validate_every_n_data_epochs

        self._checkpoint_manager = checkpoint_manager

        if checkpoint_every_n_steps is not None:
            if checkpoint_every_n_steps <= 0:
                raise ValueError(
                    "`checkpoint_every_n_steps` must be greater than zero."
                )

        self._checkpoint_after_n_steps = checkpoint_after_n_steps
        self._checkpoint_every_n_steps = checkpoint_every_n_steps

        if checkpoint_every_n_data_epochs is not None:
            if checkpoint_every_n_data_epochs <= 0:
                raise ValueError(
                    "`checkpoint_every_n_data_epochs` must be greater than zero."
                )

        self._checkpoint_after_n_data_epochs = checkpoint_after_n_data_epochs
        self._checkpoint_every_n_data_epochs = checkpoint_every_n_data_epochs

        if keep_last_n_checkpoints is not None:
            if keep_best_n_checkpoints is not None:
                raise ValueError(
                    "`keep_last_n_checkpoints` and `keep_best_n_checkpoints` are mutually exclusive and must not be specified at the same time."
                )

            if keep_last_n_checkpoints <= 0:
                raise ValueError("`keep_last_n_checkpoints` must be greater than zero.")
        elif keep_best_n_checkpoints is not None:
            if keep_best_n_checkpoints <= 0:
                raise ValueError("`keep_best_n_checkpoints` must be greater than zero.")

            if checkpoint_every_n_steps is not None:
                if score_metric_name is None:
                    raise ValueError(
                        "`score_metric_name` must be specified when `keep_best_n_checkpoints` is specified."
                    )

                if validate_every_n_steps is None:
                    raise ValueError(
                        "`validate_every_n_steps` must be specified when `keep_best_n_checkpoints` is specified."
                    )

                if checkpoint_every_n_steps % validate_every_n_steps != 0:
                    raise ValueError(
                        f"`checkpoint_every_n_steps` must be a multiple of `validate_every_n_steps` ({validate_every_n_steps}) when `keep_best_n_checkpoints` is specified, but is {checkpoint_every_n_steps} instead."
                    )

        self._keep_last_n_checkpoints = keep_last_n_checkpoints
        self._keep_best_n_checkpoints = keep_best_n_checkpoints

        if keep_last_n_models is not None:
            if keep_last_n_checkpoints is None:
                raise ValueError(
                    "`keep_last_n_models` must be `None` when `keep_last_n_checkpoints` is `None`."
                )

            if keep_last_n_checkpoints > keep_last_n_models:
                raise ValueError(
                    f"`keep_last_n_models` must be greater than or equal to `keep_last_n_checkpoints` ({keep_last_n_checkpoints}), but is {keep_last_n_models} instead."
                )

        if keep_best_n_models is not None:
            if keep_best_n_checkpoints is None:
                raise ValueError(
                    "`keep_best_n_models` must be `None` when `keep_best_n_checkpoints` is `None`."
                )

            if keep_best_n_checkpoints > keep_best_n_models:
                raise ValueError(
                    f"`keep_best_n_models` must be greater than or equal to `keep_best_n_checkpoints` ({keep_best_n_checkpoints}), but is {keep_best_n_models} instead."
                )

        self._keep_last_n_models = keep_last_n_models
        self._keep_best_n_models = keep_best_n_models

        unit.metric_bag.register_metric(
            "gradient_norm", Mean(device=device), persistent=False
        )

        self._metric_bag = unit.metric_bag

        if metric_recorders is None:  # compat
            if root_gang.rank == 0:
                self._metric_recorders = [LogMetricRecorder(log)]

                if tb_dir is not None:
                    self._metric_recorders.append(TensorBoardRecorder(tb_dir))

                if metrics_dir is not None:
                    self._metric_recorders.append(JsonFileMetricRecorder(metrics_dir))

                if wandb_options is not None:
                    wandb_dir, wandb_project, wandb_name = wandb_options

                    self._metric_recorders.append(
                        WandbRecorder(wandb_project, wandb_name, wandb_dir)
                    )
            else:
                self._metric_recorders = []
        else:
            self._metric_recorders = list(metric_recorders)

        if publish_metrics_every_n_steps == 0:
            raise ValueError(
                "`publish_metrics_every_n_steps` must be greater than zero."
            )

        self._publish_metrics_after_n_steps = publish_metrics_after_n_steps
        self._publish_metrics_every_n_steps = publish_metrics_every_n_steps

        if publish_metrics_every_n_data_epochs == 0:
            raise ValueError(
                "`publish_metrics_every_n_data_epochs` must be greater than zero."
            )

        self._publish_metrics_after_n_data_epochs = publish_metrics_after_n_data_epochs
        self._publish_metrics_every_n_data_epochs = publish_metrics_every_n_data_epochs

        if profile is None or tb_dir is None:
            if profile is not None and tb_dir is None:
                log.warning("No TensorBoard log directory provided. Profiling will be disabled.")  # fmt: skip

            num_skip_steps, num_record_steps = 1, 0

            profile_dir = Path()

            enabled = False
        else:
            num_skip_steps, num_record_steps = profile

            profile_dir = tb_dir

            enabled = num_record_steps > 0

        self._profiler = Profiler(
            num_skip_steps, num_record_steps, profile_dir, root_gang, enabled=enabled
        )

        self._anomaly_detection = anomaly_detection

        self._seed = seed

        self._rng_bag = RngBag.from_device_defaults(CPU, device)

        self._wall_watch = wall_watch

        self._total_step_time = 0.0

        self._run = False

        self._progress = create_rich_progress()
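
    # Worked example of the cadence constraint enforced in `__init__` above
    # (illustrative numbers, not defaults): with `validate_every_n_steps=500` and
    # `keep_best_n_checkpoints=3`, `checkpoint_every_n_steps` has to be one of
    # 500, 1000, 1500, ... so that every saved checkpoint has a validation score
    # available for ranking.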

    def request_stop(self) -> None:
        """Request a graceful stop of the training."""
        log.info("Stopping training after a final validation and saving checkpoint.")

        self._should_stop = True
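
    # Illustrative usage note, not fairseq2 API: on clusters that send a warning
    # signal before preemption, a launcher script could wire `request_stop()` to a
    # signal handler so the run checkpoints and exits cleanly, e.g.:
    #
    #     import signal
    #
    #     signal.signal(signal.SIGUSR1, lambda signum, frame: trainer.request_stop())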

    def __call__(self) -> None:
        if self._run:
            raise InvalidOperationError("The trainer can only be run once.")

        self._run = True

        # Set the per-rank seed for training.
        self._rng_bag.manual_seed(self._seed + self._root_gang.rank)

        try:
            self._maybe_restore_state()
        except KeyboardInterrupt:
            log.info("Training terminated!")

            raise

        log.info("Running training on {} device(s).", self._root_gang.size)

        try:
            self._do_run()
        except KeyboardInterrupt:
            log.info("Training terminated at step {}!", self._step_nr)

            raise

        if self._should_stop:
            log.info("Training stopped at step {}!", self._step_nr)

            return

        elapsed_time = self._wall_watch.get_elapsed_time()

        log.info("Training complete in {:,} seconds after {} step(s)!", int(elapsed_time), self._step_nr)  # fmt: skip

    def _maybe_restore_state(self) -> None:
        log.info("Attempting to load the last checkpoint.")

        try:
            step_nr, state = self._checkpoint_manager.load_last_checkpoint()
        except CheckpointNotFoundError:
            log.info("No checkpoint found. Starting training.")

            return

        log.info("Checkpoint loaded, restoring training from step {}.", step_nr)

        self._step_nr = step_nr

        self.load_state_dict(state)

        self._root_gang.barrier()

        log.info("Training restored, resuming.")

    def _do_run(self) -> None:
        with self._progress, self._profiler:
            self._train_task_id = self._progress.add_task(
                "train", total=self._max_num_steps, completed=self._step_nr
            )

            while self._should_run_step():
                self._maybe_advance_data_epoch()

                self._step_nr += 1

                self._progress.update(self._train_task_id, advance=1)

                detect_anomaly = torch.autograd.set_detect_anomaly(  # type: ignore[attr-defined]
                    self._anomaly_detection, check_nan=True
                )

                with detect_anomaly:
                    with record_function(f"step_{self._step_nr}"):
                        self._run_step()

                if self._should_publish_metrics():
                    self._publish_metrics()

                if self._should_validate():
                    self._validate()

                    self._maybe_request_early_stop()

                if self._should_checkpoint():
                    self._checkpoint()

                self._profiler.step()

                self._valid_score = None

    def _should_run_step(self) -> bool:
        if self._end_of_data or self._should_stop:
            return False

        if self._max_num_steps is None:
            return True

        return self._step_nr < self._max_num_steps

    def _maybe_advance_data_epoch(self) -> None:
        if not self._end_of_data_epoch:
            return

        self._data_epoch_nr += 1

        self._end_of_data_epoch = False

    def _run_step(self) -> None:
        step_nr = self._step_nr

        log.debug("{} training step {}.", "Repeating" if self._repeat_step else "Running", step_nr)  # fmt: skip

        watch = Stopwatch(start=True, device=self._root_gang.device)

        # Collect the batches.
        with record_function(f"step_{step_nr}_data_load"):
            batches = self._next_batches()
            if batches is None:
                return

        # Prepare the unit.
        if not self._repeat_step:
            with record_function(f"step_{step_nr}_unit_setup"):
                self._unit.set_step_nr(step_nr)

        num_targets = 0

        if self._loss_scaler.is_enabled:
            self._metric_bag.begin_updates()

        # Accumulate.
        for batch_nr, batch in enumerate(batches):
            with self._maybe_no_sync(batch_nr, len(batches)):
                with record_function(f"step_{step_nr}_{batch_nr}_forward"):
                    batch_loss, num_batch_targets = self._compute_loss(batch)

                if num_batch_targets is not None:
                    if num_batch_targets == 0:
                        raise ContractError(
                            "The train unit returned zero loss targets."
                        )

                    num_targets += num_batch_targets

                with record_function(f"step_{step_nr}_{batch_nr}_backward"):
                    self._loss_scaler.backward(batch_loss)

        # Normalize.
        if num_targets > 0:
            normalize_gradients(self._model, self._dp_gang, num_targets=num_targets)

        # Clip.
with record_function(f"step_{step_nr}_grad_norm"): self._loss_scaler.unscale_gradients_() # TODO(balioglu): Support tensor parallelism! grad_norm = clip_gradient_norm( self._model, max_norm=self._max_gradient_norm ) # Sanity check. if not check_gradient_norms(grad_norm, self._dp_gang, step_nr): raise FloatingPointError( f"The gradients are inconsistent between processes at step {step_nr}. Training cannot continue." ) # Update the parameters. with record_function(f"step_{step_nr}_optimizer"): _, scale_result = self._loss_scaler.run_optimizer_step(step_nr) if scale_result.overflow: self._metric_bag.rollback_updates() if scale_result.min_reached: raise FloatingPointError( f"The gradients are scaled down to minimum at step {step_nr}. Training cannot continue." ) # Repeat the step with the next batch. self._step_nr -= 1 self._progress.update(self._train_task_id, advance=-1) self._repeat_step = True else: self._lr_scheduler.step() if self._loss_scaler.is_enabled: self._metric_bag.commit_updates() self._metric_bag.gradient_norm.update(grad_norm) self._repeat_step = False self._num_effective_batches += 1 # Reset. self._optimizer.zero_grad(set_to_none=True) self._total_step_time += watch.get_elapsed_time() def _next_batches(self) -> list[BatchT] | None: try: batches = next(self._data_reader) except StopIteration: batches = None if batches is not None: self._read_data = True return batches self._data_reader.reset() self._end_of_data_epoch = True log.info("End of epoch {} reached at training step {}.", self._data_epoch_nr, self._step_nr) # fmt: skip if not self._read_data: # The dataset is empty. self._end_of_data = True elif self._max_num_data_epochs is not None: if self._data_epoch_nr >= self._max_num_data_epochs: self._end_of_data = True if self._end_of_data: log.info("End of data reached.", self._step_nr) # Repeat the step with the first batch of the next epoch. 
        self._step_nr -= 1

        self._progress.update(self._train_task_id, advance=-1)

        self._repeat_step = True

        return None

    def _maybe_no_sync(
        self, batch_nr: int, num_batches: int
    ) -> AbstractContextManager[None]:
        if batch_nr < num_batches - 1 and self._dp_gang.size > 1:
            return self._model.no_sync()  # type: ignore[no-any-return]

        return nullcontext()

    def _compute_loss(self, batch: BatchT) -> tuple[Tensor, int | None]:
        with self._maybe_autocast():
            return self._unit(batch)

    def _maybe_autocast(self) -> AbstractContextManager[None]:
        if self._dtype == torch.float32 or not self._amp:
            return nullcontext()

        return torch.autocast(device_type=self._dp_gang.device.type, dtype=self._dtype)

    def _should_publish_metrics(self) -> bool:
        return self._should_do(
            self._publish_metrics_after_n_steps,
            self._publish_metrics_every_n_steps,
            self._publish_metrics_after_n_data_epochs,
            self._publish_metrics_every_n_data_epochs,
        )

    def _publish_metrics(self) -> None:
        log.debug("Syncing metrics.")

        if self._tp_gang.rank == 0:
            values = self._metric_bag.sync_and_compute_metrics()
        else:
            values = None

        self._metric_bag.reset_non_persistent_metrics()

        if self._root_gang.rank == 0:
            if values is None:
                raise InternalError("`values` is `None`.")

            extend_batch_metrics(
                values, self._num_effective_batches, self._total_step_time
            )

            values["lr"] = get_effective_lr(self._lr_scheduler)

            values["data_epoch"] = self._data_epoch_nr

            values["elapsed_time"] = self._total_step_time

            values["wall_time"] = self._wall_watch.get_elapsed_time()

            record_metrics(self._metric_recorders, "train", values, self._step_nr)

        self._num_effective_batches = 0

        self._total_step_time = 0.0

    def _should_validate(self) -> bool:
        if not self._valid_units:
            return False

        return self._should_do(
            self._validate_after_n_steps,
            self._validate_every_n_steps,
            self._validate_after_n_data_epochs,
            self._validate_every_n_data_epochs,
        )

    def _validate(self) -> None:
        log.info("Starting validation after step {}.", self._step_nr)

        self._model.eval()

        with summon_fsdp_for_validation(self._model):
            unit_scores = []

            for unit, data_reader in zip(self._valid_units, self._valid_data_readers):
                if unit.display_name:
                    log.info("Validating {}.", unit.display_name)

                unit_score = self._validate_unit(unit, data_reader)
                if unit_score is not None:
                    unit_scores.append(unit_score)

            self._valid_score = self._compute_valid_score(unit_scores)

        self._model.train()

        log.info("Validation complete.")

    @torch.inference_mode()
    def _validate_unit(
        self, unit: EvalUnit[BatchT], data_reader: DataReader[BatchT]
    ) -> float | None:
        watch = Stopwatch(start=True, device=self._root_gang.device)

        unit.set_step_nr(self._step_nr)

        valid_task = self._progress.add_task("valid", total=None)

        num_effective_batches = 0

        for step_nr in count(start=1):
            self._progress.update(valid_task, advance=1)

            log.debug("Running validation step {}.", step_nr)

            try:
                batches = next(data_reader)
            except StopIteration:
                break

            for batch in batches:
                with self._maybe_autocast():
                    unit(batch)

            num_effective_batches += 1

        self._progress.remove_task(valid_task)

        data_reader.reset()

        metric_values = self._publish_validation_metrics(
            unit, num_effective_batches, watch.get_elapsed_time()
        )

        return self._get_unit_score(metric_values)

    def _publish_validation_metrics(
        self, unit: EvalUnit[BatchT], num_batches: int, elapsed_time: float
    ) -> dict[str, object] | None:
        log.debug("Syncing validation metrics.")

        if self._tp_gang.rank == 0:
            values = unit.metric_bag.sync_and_compute_metrics()
        else:
            values = None

        unit.metric_bag.reset_metrics()

        if self._root_gang.rank != 0:
            return None

        if values is None:
            raise InternalError("`values` is `None`.")
InternalError("`values` is `None`.") extend_batch_metrics(values, num_batches, elapsed_time) values["data_epoch"] = self._data_epoch_nr values["elapsed_time"] = elapsed_time values["wall_time"] = self._wall_watch.get_elapsed_time() if unit.display_name: run_name = "valid/" + unit.display_name else: run_name = "valid" record_metrics(self._metric_recorders, run_name, values, self._step_nr) return values def _get_unit_score(self, metric_values: dict[str, object] | None) -> float | None: if metric_values is None: return None if self._score_metric_name is None: return None score = metric_values.get(self._score_metric_name) if score is None: return None if not isinstance(score, (int, float, Tensor)): log.warning("The score metric must be of type `int`, `float`, or `torch.Tensor`.") # fmt: skip return None return float(score) def _compute_valid_score(self, unit_scores: list[float]) -> float | None: if self._score_metric_name is None: return None if not unit_scores: if self._root_gang.rank == 0: raise ContractError( "None of the validation units returned a score metric value." ) return None score = mean(unit_scores) def is_better_score() -> bool: if self._best_step_and_score is None: return True best_score = self._best_step_and_score[1] return best_score > score if self._lower_better else best_score < score if is_better_score(): self._best_step_and_score = (self._step_nr, score) if log.is_enabled_for_info(): best_step_nr, best_score = self._best_step_and_score # type: ignore[misc] if len(unit_scores) > 1: m1 = "Mean " m2 = "Best Mean " else: m1 = "" m2 = "Best " s1 = format_metric_value(self._score_metric_name, score) s2 = format_metric_value(self._score_metric_name, best_score) log.info("Score (step {}) - {}{} | {}{} at step {}", self._step_nr, m1, s1, m2, s2, best_step_nr) # fmt: skip return score def _maybe_request_early_stop(self) -> None: if self._early_stopper is None: return if self._root_gang.rank == 0: if self._valid_score is None: raise InternalError("Early stopping, but `_valid_score` is `None`.") should_stop = self._early_stopper.should_stop( self._step_nr, self._valid_score ) else: should_stop = False self._should_stop = broadcast_flag(self._root_gang, should_stop) if self._should_stop: log.info("Early stop requested. Training will be terminated after saving checkpoint.") # fmt: skip def _should_checkpoint(self) -> bool: return self._should_do( self._checkpoint_after_n_steps, self._checkpoint_every_n_steps, self._checkpoint_after_n_data_epochs, self._checkpoint_every_n_data_epochs, ) def _checkpoint(self) -> None: step_nr = self._step_nr log.info("Saving checkpoint after step {}.", step_nr) log.info("Extracting trainer state.") state = self.state_dict() log.info("Trainer state extracted.") self._checkpoint_manager.begin_checkpoint(step_nr) log.info("Saving trainer state.") if self._dp_gang.size > 1 and isinstance(self._model, DDP): replicated_keys = {"_model", "_optimizer"} else: replicated_keys = None self._checkpoint_manager.save_state( state, model_key="_model", replicated_keys=replicated_keys ) log.info("Trainer state saved.") if self._score_metric_name is not None: log.info("Saving checkpoint score.") self._checkpoint_manager.save_score(self._valid_score) log.info("Checkpoint score saved.") if isinstance(self._model, FSDP): log.info("Saving consolidated FSDP model.") self._checkpoint_manager.save_consolidated_fsdp_model(self._model) log.info("Consolidated FSDP model saved.") self._checkpoint_manager.commit_checkpoint() log.info("Checkpoint complete.") # Clean up the checkpoints. 
        nc = self._keep_last_n_checkpoints
        nm = self._keep_last_n_models

        if nm is not None:
            if nc is None:
                raise InternalError("`_keep_last_n_checkpoints` is `None`")

            self._checkpoint_manager.keep_last_n_checkpoints(nm)
            self._checkpoint_manager.keep_last_n_checkpoints(nc, preserve_model=True)
        elif nc is not None:
            self._checkpoint_manager.keep_last_n_checkpoints(nc)

        nc = self._keep_best_n_checkpoints
        nm = self._keep_best_n_models

        if nm is not None:
            if nc is None:
                raise InternalError("`_keep_best_n_checkpoints` is `None`")

            self._checkpoint_manager.keep_best_n_checkpoints(
                nm, lower_better=self._lower_better
            )
            self._checkpoint_manager.keep_best_n_checkpoints(
                nc, lower_better=self._lower_better, preserve_model=True
            )
        elif nc is not None:
            self._checkpoint_manager.keep_best_n_checkpoints(
                nc, lower_better=self._lower_better
            )

    def _should_do(
        self,
        after_n_steps: int,
        every_n_steps: int | None,
        after_n_data_epochs: int,
        every_n_data_epochs: int | None,
    ) -> bool:
        should_do_at_step = self._should_do_at_step(after_n_steps, every_n_steps)

        if self._end_of_data or self._should_stop:
            if not self._read_data:
                return False

            return not should_do_at_step

        if self._end_of_data_epoch and every_n_data_epochs is not None:
            if self._data_epoch_nr >= after_n_data_epochs:
                if self._data_epoch_nr % every_n_data_epochs == 0:
                    return not should_do_at_step

        if self._repeat_step:
            return False

        return should_do_at_step

    def _should_do_at_step(self, after_n_steps: int, every_n_steps: int | None) -> bool:
        if self._max_num_steps is not None:
            if self._step_nr >= self._max_num_steps:
                return True

        if every_n_steps is not None:
            if self._step_nr >= after_n_steps:
                return self._step_nr % every_n_steps == 0

        return False
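

# Illustrative sketch, not part of fairseq2: one plausible way to wire up and run a
# `Trainer`. All dependencies are passed in by the caller, and only keyword
# arguments that appear in `Trainer.__init__` above are used; the specific values
# (bf16 with AMP, the step budget, the cadences) are example choices, not
# recommendations.
def _example_run_training(
    unit: TrainUnit[BatchT],
    data_reader: DataReader[BatchT],
    root_gang: Gang,
    optimizer: Optimizer,
    checkpoint_manager: CheckpointManager,
    wall_watch: Stopwatch,
) -> None:
    trainer = Trainer(
        unit=unit,
        data_reader=data_reader,
        root_gang=root_gang,
        optimizer=optimizer,
        checkpoint_manager=checkpoint_manager,
        wall_watch=wall_watch,
        dtype=torch.bfloat16,
        amp=True,
        max_gradient_norm=1.0,
        max_num_steps=10_000,
        checkpoint_every_n_steps=1_000,
        publish_metrics_every_n_steps=100,
    )

    # A `Trainer` instance is a callable and may only be run once.
    trainer()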