# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import nullcontext
from itertools import count
from statistics import mean
from typing import Generic, TypeVar, final
import torch
import torch.distributed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.profiler import record_function
from torcheval.metrics import Mean
from fairseq2.checkpoint import CheckpointManager
from fairseq2.datasets import DataReader
from fairseq2.device import DeviceStatTracker
from fairseq2.error import ContractError, InternalError, InvalidOperationError
from fairseq2.gang import Gangs, broadcast_flag
from fairseq2.logging import log
from fairseq2.metrics import MetricBag, MetricDescriptor
from fairseq2.metrics.recorders import MetricRecorder
from fairseq2.nn.utils.gradient import (
check_gradient_norms,
normalize_gradients,
)
from fairseq2.optim import DynamicLossScaler
from fairseq2.optim.lr_scheduler import LRScheduler, get_effective_lr
from fairseq2.profilers import Profiler
from fairseq2.recipes.early_stopper import EarlyStopper, NoopEarlyStopper
from fairseq2.recipes.evaluator import EvalUnit
from fairseq2.recipes.metrics import extend_batch_metrics
from fairseq2.recipes.model import Model
from fairseq2.recipes.utils.progress import (
NoopProgressReporter,
ProgressReporter,
ProgressTask,
)
from fairseq2.typing import CPU, ContextManager, DataType
from fairseq2.utils.gc import GarbageCollector
from fairseq2.utils.rng import RngBag
from fairseq2.utils.state import StatefulObjectBag
from fairseq2.utils.stopwatch import Stopwatch
BatchT_contra = TypeVar("BatchT_contra", contravariant=True)
class TrainUnit(ABC, Generic[BatchT_contra]):
"""Represents a unit to be used with :class:`Trainer`."""
@abstractmethod
def __call__(self, batch: BatchT_contra) -> tuple[Tensor, int | None]:
"""Process ``batch``.
:returns:
- The loss.
- The number of targets used to compute the loss. If ``None``, the
model gradients won't be normalized.
"""
def set_step_nr(self, step_nr: int) -> None:
"""Set the current training step number."""
pass
@property
@abstractmethod
def model(self) -> Model:
"""The underlying model."""
@property
@abstractmethod
def metric_bag(self) -> MetricBag:
"""The training-related metrics."""
BatchT = TypeVar("BatchT")
@final
class Trainer(StatefulObjectBag, Generic[BatchT]):
"""Trains a machine learning model."""
_model: Model
_unit: TrainUnit[BatchT]
_data_reader: DataReader[BatchT]
_gangs: Gangs
_dtype: DataType
_amp: bool
_optimizer: Optimizer
_lr_scheduler: LRScheduler
_loss_scaler: DynamicLossScaler
_max_gradient_norm: float | None
_step_nr: int
_max_num_steps: int | None
_data_epoch_nr: int
_max_num_data_epochs: int | None
_repeat_step: bool
_has_read_any_data: bool
_num_effective_batches: int
_end_of_data_epoch: bool
_end_of_data: bool
_should_stop: bool
_score_metric_descriptor: MetricDescriptor | None
_lower_better: bool
_early_stopper: EarlyStopper | None
_best_step_and_score: tuple[int, float] | None
_valid_score: float | None
_valid_units: Sequence[EvalUnit[BatchT]]
_valid_data_readers: Sequence[DataReader[BatchT]]
_validate_after_n_steps: int
_validate_every_n_steps: int | None
_validate_after_n_data_epochs: int
_validate_every_n_data_epochs: int | None
_checkpoint_manager: CheckpointManager
_checkpoint_after_n_steps: int
_checkpoint_every_n_steps: int | None
_checkpoint_after_n_data_epochs: int
_checkpoint_every_n_data_epochs: int | None
_keep_last_n_checkpoints: int | None
_keep_best_n_checkpoints: int | None
_keep_last_n_models: int | None
_keep_best_n_models: int | None
_metric_bag: MetricBag
_metric_recorder: MetricRecorder
_publish_metrics_after_n_steps: int
_publish_metrics_every_n_steps: int | None
_publish_metrics_after_n_data_epochs: int
_publish_metrics_every_n_data_epochs: int | None
_garbage_collector: GarbageCollector
_profiler: Profiler
_device_stat_tracker: DeviceStatTracker
_gradient_check: bool
_anomaly_detection: bool
_seed: int
_rng_bag: RngBag
_wall_watch: Stopwatch
_data_read_time: float
_elapsed_time: float
_run: bool
_progress_reporter: ProgressReporter
_progress_task: ProgressTask | None
def __init__(
self,
*,
unit: TrainUnit[BatchT],
data_reader: DataReader[BatchT],
gangs: Gangs,
dtype: DataType,
amp: bool,
optimizer: Optimizer,
lr_scheduler: LRScheduler,
checkpoint_manager: CheckpointManager,
metric_recorder: MetricRecorder,
garbage_collector: GarbageCollector,
profiler: Profiler,
device_stat_tracker: DeviceStatTracker,
seed: int,
wall_watch: Stopwatch,
fp16_loss_scale: tuple[float, float] = (128.0, 0.0001),
max_gradient_norm: float | None = None,
max_num_steps: int | None = None,
max_num_data_epochs: int | None = None,
score_metric_descriptor: MetricDescriptor | None = None,
lower_better: bool = False,
early_stopper: EarlyStopper | None = None,
valid_units: Sequence[EvalUnit[BatchT]] | None = None,
valid_data_readers: Sequence[DataReader[BatchT]] | None = None,
validate_after_n_steps: int = 0,
validate_every_n_steps: int | None = None,
validate_after_n_data_epochs: int = 0,
validate_every_n_data_epochs: int | None = None,
checkpoint_after_n_steps: int = 0,
checkpoint_every_n_steps: int | None = None,
checkpoint_after_n_data_epochs: int = 0,
checkpoint_every_n_data_epochs: int | None = None,
keep_last_n_checkpoints: int | None = None,
keep_best_n_checkpoints: int | None = None,
keep_last_n_models: int | None = None,
keep_best_n_models: int | None = None,
publish_metrics_after_n_steps: int = 0,
publish_metrics_every_n_steps: int | None = None,
publish_metrics_after_n_data_epochs: int = 0,
publish_metrics_every_n_data_epochs: int | None = None,
gradient_check: bool = False,
anomaly_detection: bool = False,
) -> None:
"""
:param unit:
The training unit.
:param data_reader:
The data reader for training.
:param gangs:
The gangs to train on.
:param optimizer:
The parameter optimizer.
:param checkpoint_manager:
The checkpoint manager.
:param wall_watch:
The stopwatch to track process wall-time.
:param dtype:
The data type of the model.
:param lr_scheduler:
The learning rate scheduler.
:param amp:
If ``True``, enables ``torch.amp``.
:param fp16_loss_scale:
The initial and minimum loss scale for fp16 training.
:param max_gradient_norm:
The maximum gradient norm. If ``None``, no clipping will be applied.
:param max_num_steps:
The maximum number of steps to train for.
:param max_num_data_epochs:
The maximum number of data epochs to train for.
:param score_metric_descriptor:
The descriptor of the metric to use for score calculation.
:param lower_better:
If ``True``, lower scores are considered better.
:param early_stopper:
The early-stopper callable.
:param valid_units:
The evaluation units for validating the model.
:param valid_data_readers:
The data readers corresponding to each unit in ``valid_units``.
:param validate_after_n_steps:
The number of steps after which to start validating the model.
:param validate_every_n_steps:
The step interval at which to validate the model.
:param validate_after_n_data_epochs:
The number of data epochs after which to start validating the model.
:param validate_every_n_data_epochs:
The data epoch interval at which to validate the model.
:param checkpoint_after_n_steps:
The number of steps after which to start checkpointing.
:param checkpoint_every_n_steps:
The step interval at which to checkpoint.
:param checkpoint_after_n_data_epochs:
The number of data epochs after which to start checkpointing.
:param checkpoint_every_n_data_epochs:
The data epoch interval at which to checkpoint.
:param keep_last_n_checkpoints:
The number of checkpoints to keep. If ``None``, none will be deleted.
:param keep_best_n_checkpoints:
The number of checkpoints to keep based on their validation score.
If ``None``, none will be deleted.
:param keep_last_n_models:
The number of checkpoint models to keep. Must be greater than or
equal to ``keep_last_n_checkpoints``.
:param keep_best_n_models:
The number of best checkpoint models to keep based on their
validation score. Must be greater than or equal to
``keep_best_n_checkpoints``.
:param metric_recorder:
The metric recorder.
:param publish_metrics_after_n_steps:
The number of steps after which to start publishing metrics.
:param publish_metrics_every_n_steps:
The step interval at which to publish metrics.
:param publish_metrics_after_n_data_epochs:
The number of data epochs after which to start publishing metrics.
:param publish_metrics_every_n_data_epochs:
The data epoch interval at which to publish metrics.
        :param profiler:
            The profiler.
        :param garbage_collector:
            The garbage collector that controls Python GC during training.
        :param device_stat_tracker:
            The device statistics tracker.
        :param gradient_check:
            If ``True``, checks that gradient norms are consistent across
            processes after each step.
        :param anomaly_detection:
            If ``True``, turns on the anomaly detection feature of
            ``torch.autograd``.
:param seed:
The random number generator seed.
"""
super().__init__()
device = gangs.root.device
self.register_non_stateful("_model", unit.model)
self._unit = unit
self._data_reader = data_reader
self._gangs = gangs
self._dtype = dtype
self._amp = amp
self._optimizer = optimizer
self._lr_scheduler = lr_scheduler
fp16_init_scale, fp16_min_scale = fp16_loss_scale
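        # The loss scaler is a no-op unless training in fp16; its state is
        # treated as sharded whenever the root gang is larger than the
        # replicated data-parallel gang (e.g. under FSDP).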
self._loss_scaler = DynamicLossScaler(
optimizer,
gangs.root,
sharded=gangs.root.size != gangs.rdp.size,
init_scale=fp16_init_scale,
min_scale=fp16_min_scale,
gradient_accumulation=data_reader.num_accumulate,
enabled=dtype == torch.float16,
)
self._max_gradient_norm = max_gradient_norm
self.register_stateful("_step_nr", 0)
if max_num_steps is not None:
if max_num_steps <= 0:
raise ValueError("`max_num_steps` must be greater than zero.")
self._max_num_steps = max_num_steps
self.register_stateful("_data_epoch_nr", 1)
if max_num_data_epochs is not None:
if max_num_data_epochs <= 0:
raise ValueError("`max_num_data_epochs` must be greater than zero.")
self._max_num_data_epochs = max_num_data_epochs
self._repeat_step = False
self.register_stateful("_has_read_any_data", False)
self._num_effective_batches = 0
self._end_of_data_epoch = False
self._end_of_data = False
self._should_stop = False
self._score_metric_descriptor = score_metric_descriptor
self._lower_better = lower_better
if early_stopper is not None:
if score_metric_descriptor is None:
raise ValueError(
"`score_metric_descriptor` must be specified when `early_stopper` is specified."
)
if gangs.root.rank == 0:
self._early_stopper = early_stopper
else:
self._early_stopper = NoopEarlyStopper()
else:
self._early_stopper = None
self.register_stateful("_best_step_and_score", None)
self._valid_score = None
if valid_units is None and valid_data_readers is None:
self._valid_units = []
self._valid_data_readers = []
elif valid_units is not None and valid_data_readers is not None:
if len(valid_units) != len(valid_data_readers):
raise ValueError(
f"The number of data readers in `valid_data_readers` must match the number of units in `valid_units` ({len(valid_units)}), but is {len(valid_data_readers)} instead."
)
self._valid_units = valid_units
self._valid_data_readers = valid_data_readers
else:
raise ValueError(
"`valid_units` and `valid_data_readers` must be both specified."
)
if validate_every_n_steps is not None:
if validate_every_n_steps <= 0:
raise ValueError("`validate_every_n_steps` must be greater than zero.")
self._validate_after_n_steps = validate_after_n_steps
self._validate_every_n_steps = validate_every_n_steps
if validate_every_n_data_epochs is not None:
if validate_every_n_data_epochs <= 0:
raise ValueError(
"`validate_every_n_data_epochs` must be greater than zero."
)
self._validate_after_n_data_epochs = validate_after_n_data_epochs
self._validate_every_n_data_epochs = validate_every_n_data_epochs
self._checkpoint_manager = checkpoint_manager
if checkpoint_every_n_steps is not None:
if checkpoint_every_n_steps <= 0:
raise ValueError(
"`checkpoint_every_n_steps` must be greater than zero."
)
self._checkpoint_after_n_steps = checkpoint_after_n_steps
self._checkpoint_every_n_steps = checkpoint_every_n_steps
if checkpoint_every_n_data_epochs is not None:
if checkpoint_every_n_data_epochs <= 0:
raise ValueError(
"`checkpoint_every_n_data_epochs` must be greater than zero."
)
self._checkpoint_after_n_data_epochs = checkpoint_after_n_data_epochs
self._checkpoint_every_n_data_epochs = checkpoint_every_n_data_epochs
if keep_last_n_checkpoints is not None:
if keep_best_n_checkpoints is not None:
raise ValueError(
"`keep_last_n_checkpoints` and `keep_best_n_checkpoints` must not be specified at the same time."
)
if keep_last_n_checkpoints <= 0:
raise ValueError("`keep_last_n_checkpoints` must be greater than zero.")
elif keep_best_n_checkpoints is not None:
if keep_best_n_checkpoints <= 0:
raise ValueError("`keep_best_n_checkpoints` must be greater than zero.")
if checkpoint_every_n_steps is not None:
if score_metric_descriptor is None:
raise ValueError(
"`score_metric_descriptor` must be specified when `keep_best_n_checkpoints` is specified."
)
if validate_every_n_steps is None:
raise ValueError(
"`validate_every_n_steps` must be specified when `keep_best_n_checkpoints` is specified."
)
if checkpoint_every_n_steps % validate_every_n_steps != 0:
raise ValueError(
f"`checkpoint_every_n_steps` must be a multiple of `validate_every_n_steps` ({validate_every_n_steps}) when `keep_best_n_checkpoints` is specified, but is {checkpoint_every_n_steps} instead."
)
self._keep_last_n_checkpoints = keep_last_n_checkpoints
self._keep_best_n_checkpoints = keep_best_n_checkpoints
if keep_last_n_models is not None:
if keep_last_n_checkpoints is None:
raise ValueError(
"`keep_last_n_models` must not be specified when `keep_last_n_checkpoints` is not specified."
)
if keep_last_n_checkpoints > keep_last_n_models:
raise ValueError(
f"`keep_last_n_models` must be greater than or equal to `keep_last_n_checkpoints` ({keep_last_n_checkpoints}), but is {keep_last_n_models} instead."
)
if keep_best_n_models is not None:
if keep_best_n_checkpoints is None:
raise ValueError(
"`keep_best_n_models` must not be specified when `keep_best_n_checkpoints` is not specified."
)
if keep_best_n_checkpoints > keep_best_n_models:
raise ValueError(
f"`keep_best_n_models` must be greater than or equal to `keep_best_n_checkpoints` ({keep_best_n_checkpoints}), but is {keep_best_n_models} instead."
)
self._keep_last_n_models = keep_last_n_models
self._keep_best_n_models = keep_best_n_models
unit.metric_bag.register_metric(
"gradient_norm", Mean(device=device), persistent=False
)
self._metric_bag = unit.metric_bag
self._metric_recorder = metric_recorder
        if publish_metrics_every_n_steps is not None:
            if publish_metrics_every_n_steps <= 0:
                raise ValueError(
                    "`publish_metrics_every_n_steps` must be greater than zero."
                )
self._publish_metrics_after_n_steps = publish_metrics_after_n_steps
self._publish_metrics_every_n_steps = publish_metrics_every_n_steps
        if publish_metrics_every_n_data_epochs is not None:
            if publish_metrics_every_n_data_epochs <= 0:
                raise ValueError(
                    "`publish_metrics_every_n_data_epochs` must be greater than zero."
                )
self._publish_metrics_after_n_data_epochs = publish_metrics_after_n_data_epochs
self._publish_metrics_every_n_data_epochs = publish_metrics_every_n_data_epochs
self._garbage_collector = garbage_collector
self._profiler = profiler
self._device_stat_tracker = device_stat_tracker
self._gradient_check = gradient_check
self._anomaly_detection = anomaly_detection
self._seed = seed
self._rng_bag = RngBag.from_device_defaults(CPU, device)
self._wall_watch = wall_watch
self._data_read_time = 0.0
self._elapsed_time = 0.0
self._run = False
self._progress_reporter = NoopProgressReporter()
self._progress_task = None
def request_stop(self) -> None:
"""Request a graceful stop of the training."""
log.info("Stopping training after a final validation and saving checkpoint.")
self._should_stop = True
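    # A typical invocation (hypothetical, for illustration only): construct the
    # trainer once with a unit, data reader, gangs, optimizer, etc., then call
    # it to run the full training loop:
    #
    #     trainer = Trainer(unit=unit, data_reader=reader, gangs=gangs, ...)
    #     trainer(progress_reporter)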
def __call__(self, progress_reporter: ProgressReporter | None = None) -> None:
if self._run:
raise InvalidOperationError("The trainer can only be run once.")
self._run = True
if progress_reporter is not None:
self._progress_reporter = progress_reporter
self._rng_bag.manual_seed(self._seed + self._gangs.root.rank)
try:
self._maybe_restore_state()
except KeyboardInterrupt:
log.info("Training terminated!")
raise
log.info("Running training on {} device(s).", self._gangs.root.size)
try:
self._do_run()
except KeyboardInterrupt:
log.info("Training terminated at step {}!", self._step_nr)
raise
finally:
self._garbage_collector.enable(False)
self._gangs.close()
if self._should_stop:
log.info("Training stopped at step {}!", self._step_nr)
return
elapsed_time = self._wall_watch.get_elapsed_time()
log.info("Training complete in {:,} seconds after {} step(s)!", int(elapsed_time), self._step_nr) # fmt: skip
def _maybe_restore_state(self) -> None:
step_nr = self._checkpoint_manager.get_last_step_number()
if step_nr is None:
return
self._step_nr = step_nr
log.info("Restoring training from the last checkpoint at step {}.", step_nr) # fmt: skip
state_dict = self._checkpoint_manager.load_checkpoint(step_nr)
self.load_state_dict(state_dict)
self._gangs.root.barrier()
log.info("Training restored. Resuming.")
def _do_run(self) -> None:
self._model.module.train()
self._garbage_collector.enable()
with self._progress_reporter, self._profiler:
self._progress_task = self._progress_reporter.create_task(
"train", total=self._max_num_steps, completed=self._step_nr
)
self._device_stat_tracker.reset()
first_iter = True
while self._should_run_step():
self._maybe_advance_data_epoch()
self._step_nr += 1
self._progress_task.step(1)
detect_anomaly = torch.autograd.set_detect_anomaly( # type: ignore[attr-defined]
self._anomaly_detection, check_nan=True
)
with detect_anomaly:
with record_function(f"step_{self._step_nr}"):
self._run_step()
if self._should_publish_metrics():
self._publish_metrics()
if self._should_validate():
self._validate()
self._maybe_request_early_stop()
if self._should_checkpoint():
self._checkpoint()
self._profiler.step()
self._garbage_collector.step()
self._valid_score = None
if first_iter:
# Emptying the CUDA memory allocator cache after the first
# iteration can reduce fragmentation and avoid OOM.
if self._gangs.root.device.type == "cuda":
torch.cuda.empty_cache()
first_iter = False
def _should_run_step(self) -> bool:
if self._end_of_data or self._should_stop:
return False
if self._max_num_steps is None:
return True
return self._step_nr < self._max_num_steps
def _maybe_advance_data_epoch(self) -> None:
if self._end_of_data_epoch:
self._data_epoch_nr += 1
self._end_of_data_epoch = False
def _run_step(self) -> None:
step_nr = self._step_nr
log.debug("{} training step {}.", "Repeating" if self._repeat_step else "Running", step_nr) # fmt: skip
watch = Stopwatch(start=True, device=self._gangs.root.device)
# Collect the batches.
with record_function(f"step_{step_nr}_data_load"):
batches = self._next_batches()
if batches is None:
return
# Prepare the unit.
if not self._repeat_step:
with record_function(f"step_{step_nr}_unit_setup"):
self._unit.set_step_nr(step_nr)
num_targets = 0
if self._loss_scaler.is_enabled:
self._metric_bag.begin_updates()
# Accumulate.
for batch_nr, batch in enumerate(batches):
with self._maybe_no_sync(batch_nr, len(batches)):
with record_function(f"step_{step_nr}_{batch_nr}_forward"):
batch_loss, num_batch_targets = self._compute_loss(batch)
if num_batch_targets is not None:
if num_batch_targets == 0:
raise ContractError(
"The train unit returned zero loss targets."
)
num_targets += num_batch_targets
with record_function(f"step_{step_nr}_{batch_nr}_backward"):
self._loss_scaler.backward(batch_loss)
# Normalize.
if num_targets > 0:
normalize_gradients(
self._model.module, self._gangs.dp, num_targets=num_targets
)
# Clip.
with record_function(f"step_{step_nr}_grad_norm"):
self._loss_scaler.unscale_gradients_()
# TODO(balioglu): Support tensor parallelism!
grad_norm = self._model.clip_gradient_norm(self._max_gradient_norm)
if self._gradient_check:
# Sanity check.
if not check_gradient_norms(grad_norm, self._gangs.dp, step_nr):
raise FloatingPointError(
f"The gradients are inconsistent between processes at step {step_nr}. Training cannot continue."
)
# Update the parameters.
with record_function(f"step_{step_nr}_optimizer"):
_, scale_result = self._loss_scaler.run_optimizer_step(step_nr)
self._repeat_step = scale_result.overflow
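        # If the fp16 loss scale overflowed, the optimizer step was skipped;
        # roll back the metric updates and the step counters so that this step
        # is repeated with a reduced scale.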
if self._repeat_step:
self._metric_bag.rollback_updates()
if scale_result.min_reached:
raise FloatingPointError(
f"The gradients are scaled down to minimum at step {step_nr}. Training cannot continue."
)
self._step_nr -= 1
if self._progress_task is None:
raise InternalError("`_progress_task` is `None`.")
self._progress_task.step(-1)
else:
self._lr_scheduler.step()
if self._loss_scaler.is_enabled:
self._metric_bag.commit_updates()
self._metric_bag.gradient_norm.update(grad_norm)
self._num_effective_batches += 1
# Reset the grads.
self._optimizer.zero_grad(set_to_none=True)
self._elapsed_time += watch.get_elapsed_time()
def _next_batches(self) -> list[BatchT] | None:
watch = Stopwatch(start=True)
try:
batches = next(self._data_reader)
except StopIteration:
batches = None
self._data_read_time += watch.get_elapsed_time()
if batches is not None:
self._has_read_any_data = True
return batches
self._data_reader.reset()
self._end_of_data_epoch = True
log.info("End of epoch {} reached at training step {}.", self._data_epoch_nr, self._step_nr) # fmt: skip
if not self._has_read_any_data:
self._end_of_data = True
elif self._max_num_data_epochs is not None:
if self._data_epoch_nr >= self._max_num_data_epochs:
self._end_of_data = True
if self._end_of_data:
log.info("End of data reached.", self._step_nr)
else:
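            # The epoch ended but training continues; rewind the step counter
            # and repeat this step number with data from the next epoch.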
self._repeat_step = True
self._step_nr -= 1
if self._progress_task is None:
raise InternalError("`_progress_task` is `None`.")
self._progress_task.step(-1)
return None
def _maybe_no_sync(self, batch_nr: int, num_batches: int) -> ContextManager:
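        # Skip gradient synchronization for all but the last batch of a
        # gradient accumulation cycle to avoid redundant gradient reduction
        # (e.g. all-reduce under DDP).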
if batch_nr < num_batches - 1:
return self._model.no_sync()
return nullcontext()
def _compute_loss(self, batch: BatchT) -> tuple[Tensor, int | None]:
with self._maybe_autocast():
return self._unit(batch)
def _maybe_autocast(self) -> ContextManager:
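        # Autocast is only needed for mixed-precision (AMP) training; plain
        # fp32 training, or reduced-precision training without AMP, runs the
        # forward pass as-is.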
if self._dtype == torch.float32 or not self._amp:
return nullcontext()
return torch.autocast(device_type=self._gangs.dp.device.type, dtype=self._dtype)
def _should_publish_metrics(self) -> bool:
return self._should_do(
self._publish_metrics_after_n_steps,
self._publish_metrics_every_n_steps,
self._publish_metrics_after_n_data_epochs,
self._publish_metrics_every_n_data_epochs,
)
def _publish_metrics(self) -> None:
log.debug("Syncing metrics.")
if self._gangs.tp.rank == 0:
values = self._metric_bag.sync_and_compute_metrics()
else:
values = None
self._metric_bag.reset_non_persistent_metrics()
if self._gangs.root.rank == 0:
if values is None:
raise InternalError("`values` is `None`.")
extend_batch_metrics(
values, self._num_effective_batches, self._elapsed_time
)
device_stats = self._device_stat_tracker.get_stats()
values.update(device_stats)
values["lr"] = get_effective_lr(self._lr_scheduler)
values["data_epoch"] = self._data_epoch_nr
values["data_read_time"] = self._data_read_time
values["elapsed_time"] = self._elapsed_time
values["wall_time"] = self._wall_watch.get_elapsed_time()
self._metric_recorder.record_metrics("train", values, self._step_nr)
self._num_effective_batches = 0
self._data_read_time = 0.0
self._elapsed_time = 0.0
self._device_stat_tracker.reset()
self._gangs.root.barrier()
def _should_validate(self) -> bool:
if not self._valid_units:
return False
return self._should_do(
self._validate_after_n_steps,
self._validate_every_n_steps,
self._validate_after_n_data_epochs,
self._validate_every_n_data_epochs,
)
def _validate(self) -> None:
log.info("Starting validation after step {}.", self._step_nr)
self._model.module.eval()
with self._model.summon_full_parameters():
unit_scores = []
for unit, data_reader in zip(self._valid_units, self._valid_data_readers):
if unit.display_name:
log.info("Validating {}.", unit.display_name)
unit_score = self._validate_unit(unit, data_reader)
if unit_score is not None:
unit_scores.append(unit_score)
self._valid_score = self._compute_valid_score(unit_scores)
self._model.module.train()
log.info("Validation complete.")
@torch.inference_mode()
def _validate_unit(
self, unit: EvalUnit[BatchT], data_reader: DataReader[BatchT]
) -> float | None:
watch = Stopwatch(start=True, device=self._gangs.root.device)
unit.set_step_nr(self._step_nr)
task = self._progress_reporter.create_task("valid", total=None)
num_effective_batches = 0
for step_nr in count(start=1):
task.step(1)
log.debug("Running validation step {}.", step_nr)
try:
batches = next(data_reader)
except StopIteration:
break
for batch in batches:
with self._maybe_autocast():
unit(batch)
num_effective_batches += 1
task.close()
data_reader.reset()
metric_values = self._publish_validation_metrics(
unit, num_effective_batches, watch.get_elapsed_time()
)
return self._get_unit_score(metric_values)
def _publish_validation_metrics(
self, unit: EvalUnit[BatchT], num_batches: int, elapsed_time: float
) -> dict[str, object] | None:
log.debug("Syncing validation metrics.")
if self._gangs.tp.rank == 0:
values = unit.metric_bag.sync_and_compute_metrics()
else:
values = None
unit.metric_bag.reset_metrics()
if self._gangs.root.rank == 0:
if values is None:
raise InternalError("`values` is `None`.")
extend_batch_metrics(values, num_batches, elapsed_time)
values["data_epoch"] = self._data_epoch_nr
values["elapsed_time"] = elapsed_time
values["wall_time"] = self._wall_watch.get_elapsed_time()
if unit.display_name:
run_name = "valid/" + unit.display_name
else:
run_name = "valid"
self._metric_recorder.record_metrics(run_name, values, self._step_nr)
self._gangs.root.barrier()
return values
def _get_unit_score(self, metric_values: dict[str, object] | None) -> float | None:
if metric_values is None:
return None
if self._score_metric_descriptor is None:
return None
score = metric_values.get(self._score_metric_descriptor.name)
if score is None:
return None
if not isinstance(score, (int, float, Tensor)):
log.warning("The score metric must be of type `int`, `float`, or `torch.Tensor`.") # fmt: skip
return None
return float(score)
def _compute_valid_score(self, unit_scores: list[float]) -> float | None:
if self._score_metric_descriptor is None:
return None
if not unit_scores:
if self._gangs.root.rank == 0:
raise ContractError(
"None of the validation units returned a score metric value."
)
return None
last_score = mean(unit_scores)
def is_last_score_better() -> bool:
if self._best_step_and_score is None:
return True
best_score = self._best_step_and_score[1]
if self._lower_better:
return best_score > last_score
else:
return last_score > best_score
if is_last_score_better():
self._best_step_and_score = (self._step_nr, last_score)
if log.is_enabled_for_info():
best_step_nr, best_score = self._best_step_and_score # type: ignore[misc]
if len(unit_scores) > 1:
m1 = "Mean "
m2 = "Best Mean "
else:
m1 = ""
m2 = "Best "
v1 = self._score_metric_descriptor.formatter(last_score)
v2 = self._score_metric_descriptor.formatter(best_score)
s1 = f"{self._score_metric_descriptor.display_name}: {v1}"
s2 = f"{self._score_metric_descriptor.display_name}: {v2}"
log.info("Score (step {}) - {}{} | {}{} at step {}", self._step_nr, m1, s1, m2, s2, best_step_nr) # fmt: skip
return last_score
def _maybe_request_early_stop(self) -> None:
if self._early_stopper is None:
return
if self._gangs.root.rank == 0:
if self._valid_score is None:
raise InternalError("Early stopping, but `_valid_score` is `None`.")
should_stop = self._early_stopper.should_stop(
self._step_nr, self._valid_score
)
else:
should_stop = False
self._should_stop = broadcast_flag(self._gangs.root, should_stop)
if self._should_stop:
log.info("Early stop requested. Training will be terminated after saving checkpoint.") # fmt: skip
def _should_checkpoint(self) -> bool:
return self._should_do(
self._checkpoint_after_n_steps,
self._checkpoint_every_n_steps,
self._checkpoint_after_n_data_epochs,
self._checkpoint_every_n_data_epochs,
)
def _checkpoint(self) -> None:
step_nr = self._step_nr
log.info("Saving checkpoint after step {}.", step_nr)
state = self.state_dict()
self._checkpoint_manager.begin_checkpoint(step_nr)
log.info("Saving the trainer state.")
if self._gangs.dp.size > 1 and isinstance(self._model.module, DDP):
replicated_keys = {"_optimizer"}
else:
replicated_keys = None
self._checkpoint_manager.save_state(state, replicated_keys=replicated_keys)
log.info("Trainer state saved.")
log.info("Extracting the model state on rank 0.")
model_state = self._model.state_dict()
log.info("Model state extracted.")
log.info("Saving the model.")
self._checkpoint_manager.save_model_state_dict(model_state)
log.info("Model saved.")
if self._score_metric_descriptor is not None:
log.info("Saving the score.")
self._checkpoint_manager.save_score(self._valid_score, self._lower_better)
log.info("Score saved.")
self._checkpoint_manager.commit_checkpoint()
log.info("Checkpoint complete.")
# Clean up the checkpoints.
nc = self._keep_last_n_checkpoints
nm = self._keep_last_n_models
if nm is not None:
if nc is None:
raise InternalError("`_keep_last_n_checkpoints` is `None`")
self._checkpoint_manager.keep_last_n_checkpoints(nm)
self._checkpoint_manager.keep_last_n_checkpoints(nc, preserve_model=True)
elif nc is not None:
self._checkpoint_manager.keep_last_n_checkpoints(nc)
nc = self._keep_best_n_checkpoints
nm = self._keep_best_n_models
if nm is not None:
if nc is None:
raise InternalError("`_keep_best_n_checkpoints` is `None`")
self._checkpoint_manager.keep_best_n_checkpoints(nm)
self._checkpoint_manager.keep_best_n_checkpoints(nc, preserve_model=True)
elif nc is not None:
self._checkpoint_manager.keep_best_n_checkpoints(nc)
def _should_do(
self,
after_n_steps: int,
every_n_steps: int | None,
after_n_data_epochs: int,
every_n_data_epochs: int | None,
) -> bool:
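        # Decide whether a periodic action (publish/validate/checkpoint) fires.
        # At the end of data, after a stop request, or at a qualifying data
        # epoch boundary, it fires only if the step schedule does not already
        # cover the current step; it never fires while a step is being
        # repeated; otherwise the step schedule decides.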
should_do_at_step = self._should_do_at_step(after_n_steps, every_n_steps)
if self._end_of_data or self._should_stop:
if not self._has_read_any_data:
return False
return not should_do_at_step
if self._end_of_data_epoch and every_n_data_epochs is not None:
if self._data_epoch_nr >= after_n_data_epochs:
if self._data_epoch_nr % every_n_data_epochs == 0:
return not should_do_at_step
if self._repeat_step:
return False
return should_do_at_step
def _should_do_at_step(self, after_n_steps: int, every_n_steps: int | None) -> bool:
if self._max_num_steps is not None:
if self._step_nr >= self._max_num_steps:
return True
if every_n_steps is not None:
if self._step_nr >= after_n_steps:
return self._step_nr % every_n_steps == 0
return False