fairseq2.recipe.optim
This module provides helper functions to add support for new optimizers and learning rate schedulers in recipes.
- fairseq2.recipe.optim.maybe_raise_param_group_length_error(field, value, num_param_groups)[source]
  A helper function that raises ValidationError if the length of a learning rate scheduler configuration field (len(value)) does not match the number of optimizer parameter groups (num_param_groups).
  - Parameters:
    - field – The name of the learning rate scheduler configuration field, used in the error message.
    - value – The value of the configuration field, expected to hold one entry per parameter group.
    - num_param_groups – The number of optimizer parameter groups.
  - Raises:
    ValidationError – If len(value) does not match num_param_groups.
A basic use of maybe_raise_param_group_length_error

```python
from torch.optim import Optimizer

from fairseq2.recipe.config import MyleLRConfig
from fairseq2.recipe.optim import maybe_raise_param_group_length_error


def get_start_lr(config: MyleLRConfig, optimizer: Optimizer) -> list[float]:
    num_param_groups = len(optimizer.param_groups)

    start_lr: float | list[float] = config.start_lr

    # A single float applies the same start learning rate to every group.
    if isinstance(start_lr, float):
        return [start_lr] * num_param_groups

    # Otherwise, the list must have exactly one entry per parameter group.
    maybe_raise_param_group_length_error("start_lr", start_lr, num_param_groups)

    return start_lr
```