fairseq2.recipe.optim

This module provides helper functions to add support for new optimizers and learning rate schedulers in recipes.

fairseq2.recipe.optim.maybe_raise_param_group_length_error(field, value, num_param_groups)

A helper function that raises ValidationError if the length of a learning rate scheduler configuration field (len(value)) does not match the number of optimizer parameter groups (num_param_groups).

Parameters:
  • field (str) – The name of the configuration field that holds value.

  • value (Sequence[object]) – The value whose length to check.

  • num_param_groups (int) – The number of optimizer parameter groups.

Raises:

ValidationError – if len(value) does not match num_param_groups.
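
For illustration, a minimal sketch of the check in isolation (the field name and list values below are arbitrary): a value whose length matches the group count passes silently, while a mismatched length raises.

from fairseq2.recipe.optim import maybe_raise_param_group_length_error

# Length matches the number of parameter groups: no error.
maybe_raise_param_group_length_error("start_lr", [0.1, 0.2, 0.3], 3)

# Length 2 against 3 parameter groups: raises ValidationError.
maybe_raise_param_group_length_error("start_lr", [0.1, 0.2], 3)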

A basic use of maybe_raise_param_group_length_error:
from torch.optim import Optimizer

from fairseq2.recipe.config import MyleLRConfig
from fairseq2.recipe.optim import maybe_raise_param_group_length_error

def get_start_lr(config: MyleLRConfig, optimizer: Optimizer) -> list[float]:
    num_param_groups = len(optimizer.param_groups)

    start_lr: float | list[float] = config.start_lr

    # A single float is broadcast to every parameter group.
    if isinstance(start_lr, float):
        return [start_lr] * num_param_groups

    # A list must provide one start learning rate per parameter group.
    maybe_raise_param_group_length_error("start_lr", start_lr, num_param_groups)

    return start_lr
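
In context, the helper lets get_start_lr accept either form of start_lr. A usage sketch, assuming MyleLRConfig can be constructed directly with start_lr as a keyword argument (its constructor is not shown above):

import torch
from torch.optim import SGD

from fairseq2.recipe.config import MyleLRConfig

# An optimizer with two parameter groups.
optimizer = SGD(
    [
        {"params": [torch.nn.Parameter(torch.zeros(4))]},
        {"params": [torch.nn.Parameter(torch.zeros(4))]},
    ],
    lr=0.1,
)

# A scalar start_lr is broadcast to both groups.
config = MyleLRConfig(start_lr=0.05)  # assumed keyword construction
print(get_start_lr(config, optimizer))  # [0.05, 0.05]

# A one-element list against two groups raises ValidationError.
config = MyleLRConfig(start_lr=[0.05])
get_start_lr(config, optimizer)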