fairseq2.recipe.optim¶
This module provides helper functions that support adding new optimizers and learning rate schedulers to recipes.
Functions¶
- fairseq2.recipe.optim.prepare_parameter_groups(model: RecipeModel, group_configs: Sequence[ParameterGroupConfig]) → Iterable[Tensor] | Iterable[dict[str, object]] [source]¶
Prepares the parameter groups to pass to an optimizer based on the specified model and group recipe configurations.
Returns an Iterable that can be passed as an argument to the params parameter of a PyTorch Optimizer (a short sketch after the example below shows this with a stock optimizer).

Fields in group_configs whose value is set to default will use the value from the corresponding top-level configuration. For instance, if AdamWGroupConfig.betas is set to default, the optimizer will use the value of AdamWConfig.betas.

Note that the order of groups matters when determining which parameter belongs to which group. Each parameter is assigned to the first group in the list that matches its name, so list more specific groups before more general ones.
An example use of prepare_parameter_groups¶

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass, field

from torch.optim import Optimizer

from fairseq2.recipe import RecipeModel, TrainRecipe
from fairseq2.recipe.component import register_component
from fairseq2.recipe.config import Default, ParameterGroupConfig, default
from fairseq2.recipe.optim import prepare_parameter_groups
from fairseq2.runtime.dependency import DependencyContainer, DependencyResolver


@dataclass
class MyOptimizerConfig:
    """The top-level recipe configuration of MyOptimizer."""

    lr: float = 0.1
    """The default top-level learning rate."""

    betas: tuple[float, float] = (0.9, 0.99)
    """The default top-level beta values."""

    # `from __future__ import annotations` lets this field reference
    # MyOptimizerGroupConfig before it is defined below.
    groups: Sequence[MyOptimizerGroupConfig] = field(default_factory=list)
    """The configuration of individual parameter groups."""


@dataclass
class MyOptimizerGroupConfig(ParameterGroupConfig):
    """The parameter group configuration of MyOptimizer."""

    lr: float | Default = default
    """If specified, overrides the top-level value."""

    betas: tuple[float, float] | Default = default
    """If specified, overrides the top-level value."""


class MyOptimizer(Optimizer):
    ...


def create_my_optimizer(
    resolver: DependencyResolver, config: MyOptimizerConfig
) -> MyOptimizer:
    model = resolver.resolve(RecipeModel)

    # Converts group configurations to an iterable of parameter groups
    # that can be passed to an optimizer.
    parameters = prepare_parameter_groups(model, config.groups)

    # Initialize the optimizer with `parameters`.
    return MyOptimizer(parameters, config.lr, config.betas)


class MyTrainRecipe(TrainRecipe):
    def register(self, container: DependencyContainer) -> None:
        register_component(
            container,
            Optimizer,
            name="my_optimizer",
            config_kls=MyOptimizerConfig,
            factory=create_my_optimizer,
        )

    ...
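Because the returned value is an ordinary iterable of PyTorch parameter groups, it can also be handed to a stock optimizer. The following is a minimal sketch rather than part of the module: the create_adamw factory name is hypothetical, and it reuses MyOptimizerConfig from the example above (fairseq2's built-in AdamW support has its own AdamWConfig).

from torch.optim import AdamW

from fairseq2.recipe import RecipeModel
from fairseq2.recipe.optim import prepare_parameter_groups
from fairseq2.runtime.dependency import DependencyResolver


def create_adamw(resolver: DependencyResolver, config: MyOptimizerConfig) -> AdamW:
    model = resolver.resolve(RecipeModel)

    # Each parameter lands in the first entry of `config.groups` that matches
    # its name; group fields left at `default` fall back to the top-level
    # values, as described above.
    parameters = prepare_parameter_groups(model, config.groups)

    # The returned iterable is accepted directly by the `params` argument of a
    # stock PyTorch optimizer.
    return AdamW(parameters, lr=config.lr, betas=config.betas)

Because earlier entries in config.groups take precedence, a catch-all group, if used, should be listed last.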
- fairseq2.recipe.optim.maybe_raise_param_group_length_error(field: str, value: Sequence[object], num_param_groups: int) → None [source]¶
Raises ValidationError if the length of a learning rate scheduler configuration field (i.e. len(value)) does not match the number of optimizer parameter groups. A short sketch after the example below exercises both the passing and the failing case.

- Raises:
  ValidationError – If len(value) does not match num_param_groups.
A basic use of maybe_raise_param_group_length_error¶

from torch.optim import Optimizer

from fairseq2.recipe.config import MyleLRConfig
from fairseq2.recipe.optim import maybe_raise_param_group_length_error


def get_start_lr(config: MyleLRConfig, optimizer: Optimizer) -> list[float]:
    num_param_groups = len(optimizer.param_groups)

    start_lr: float | list[float] = config.start_lr

    if isinstance(start_lr, float):
        return [start_lr] * num_param_groups

    maybe_raise_param_group_length_error("start_lr", start_lr, num_param_groups)

    return start_lr
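The check can also be exercised on its own. The sketch below is illustrative (the two-group SGD setup and the learning-rate values are made up): a list whose length matches the number of parameter groups passes silently, while a mismatched list raises ValidationError.

from torch.nn import Linear
from torch.optim import SGD

from fairseq2.recipe.optim import maybe_raise_param_group_length_error

encoder, decoder = Linear(16, 16), Linear(16, 16)

# An SGD optimizer with two parameter groups.
optimizer = SGD(
    [{"params": encoder.parameters()}, {"params": decoder.parameters(), "lr": 0.01}],
    lr=0.1,
)

num_param_groups = len(optimizer.param_groups)  # 2

# Length matches the number of parameter groups: no error is raised.
maybe_raise_param_group_length_error("start_lr", [1e-7, 1e-7], num_param_groups)

# Length mismatch (3 values for 2 groups): raises ValidationError.
maybe_raise_param_group_length_error("start_lr", [1e-7, 1e-7, 1e-7], num_param_groups)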