fairseq2.recipe.optim

This module provides helper functions that support adding new optimizers and learning rate schedulers to recipes.

Functions

fairseq2.recipe.optim.prepare_parameter_groups(model: RecipeModel, group_configs: Sequence[ParameterGroupConfig]) → Iterable[Tensor] | Iterable[dict[str, object]][source]

Prepares the parameter groups to pass to an optimizer based on the specified model and group recipe configurations.

Returns an Iterable that can be passed as an argument to the params parameter of a PyTorch Optimizer.
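
As a minimal sketch (not part of the example below), assuming a RecipeModel instance named model and a sequence of ParameterGroupConfig instances named group_configs are already available, the returned iterable can be handed directly to a standard PyTorch optimizer such as torch.optim.AdamW:

from torch.optim import AdamW

from fairseq2.recipe.optim import prepare_parameter_groups

# `model` is a RecipeModel and `group_configs` a sequence of
# ParameterGroupConfig instances; both are assumed to exist already.
params = prepare_parameter_groups(model, group_configs)

# The returned iterable is accepted by the `params` argument of any
# PyTorch optimizer, e.g. AdamW.
optimizer = AdamW(params, lr=1e-4)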

Fields in group_configs whose value is left as default fall back to the corresponding field of the top-level configuration. For instance, if AdamWGroupConfig.betas is set to default, the optimizer will use the value of AdamWConfig.betas.

Note that the order of groups matters when determining which parameter belongs to which group: each parameter is assigned to the first group in the list that matches its name, so be sure to list the groups in the intended order.

An example use of prepare_parameter_groups
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass, field

from torch.optim import Optimizer

from fairseq2.recipe import RecipeModel, TrainRecipe
from fairseq2.recipe.component import register_component
from fairseq2.recipe.config import Default, ParameterGroupConfig, default
from fairseq2.recipe.optim import prepare_parameter_groups
from fairseq2.runtime.dependency import DependencyContainer, DependencyResolver

@dataclass
class MyOptimizerConfig:
    """The top-level recipe configuration of MyOptimizer."""

    lr: float = 0.1
    """The default top-level learning rate."""

    betas: tuple[float, float] = (0.9, 0.99)
    """The default top-level beta values."""

    groups: Sequence[MyOptimizerGroupConfig] = field(default_factory=list)
    """The configuration of individual parameter groups."""


@dataclass
class MyOptimizerGroupConfig(ParameterGroupConfig):
    """The parameter group configuration of MyOptimizer."""

    lr: float | Default = default
    """If specified, overrides the top-level value."""

    betas: tuple[float, float] | Default = default
    """If specified, overrides the top-level value."""


class MyOptimizer(Optimizer):
    ...


def create_my_optimizer(
    resolver: DependencyResolver, config: MyOptimizerConfig
) -> MyOptimizer:
    model = resolver.resolve(RecipeModel)

    # Converts group configurations to an iterable of parameter groups
    # that can be passed to an optimizer.
    parameters = prepare_parameter_groups(model, config.groups)

    # Initialize the optimizer with `parameters`.
    return MyOptimizer(parameters, config.lr, config.betas)


class MyTrainRecipe(TrainRecipe):
    def register(self, container: DependencyContainer) -> None:
        register_component(
            container,
            Optimizer,
            name="my_optimizer",
            config_kls=MyOptimizerConfig,
            factory=create_my_optimizer,
        )

    ...
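
To illustrate the default fallback described above, here is a hypothetical configuration built from the MyOptimizerConfig classes in the example; the values are purely illustrative, and ParameterGroupConfig may define additional fields (such as a parameter name pattern) that are omitted here:

config = MyOptimizerConfig(
    lr=0.1,
    betas=(0.9, 0.99),
    groups=[
        # `lr` is overridden for this group, while `betas` is left as
        # `default` and therefore falls back to the top-level (0.9, 0.99).
        MyOptimizerGroupConfig(lr=0.01),
    ],
)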

fairseq2.recipe.optim.maybe_raise_param_group_length_error(field: str, value: Sequence[object], num_param_groups: int) → None[source]

Raises ValidationError if the length of a learning rate scheduler configuration field (i.e. len(value)) does not match the number of optimizer parameter groups.

Raises:

ValidationError – If len(value) does not match num_param_groups.

A basic use of maybe_raise_param_group_length_error
from torch.optim import Optimizer

from fairseq2.recipe.config import MyleLRConfig
from fairseq2.recipe.optim import maybe_raise_param_group_length_error

def get_start_lr(config: MyleLRConfig, optimizer: Optimizer) -> list[float]:
    num_param_groups = len(optimizer.param_groups)

    start_lr: float | list[float] = config.start_lr

    # A single float applies uniformly to all parameter groups.
    if isinstance(start_lr, float):
        return [start_lr] * num_param_groups

    # Otherwise, `start_lr` must provide one value per parameter group;
    # a mismatched length raises ValidationError.
    maybe_raise_param_group_length_error("start_lr", start_lr, num_param_groups)

    return start_lr
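
As a final sketch (hypothetical values), passing two per-group values for an optimizer with three parameter groups makes the helper raise ValidationError:

from fairseq2.recipe.optim import maybe_raise_param_group_length_error

# Two values but three parameter groups: the lengths do not match,
# so ValidationError is raised.
maybe_raise_param_group_length_error("start_lr", [1e-7, 1e-6], 3)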