# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass
from typing import Final, cast, final

import torch
import torch.distributed
from torch import Tensor
from typing_extensions import override

from fairseq2.datasets.preference import PreferenceBatch
from fairseq2.gang import Gang, Gangs
from fairseq2.metrics import Mean
from fairseq2.models.sequence import SequenceModelOutput, as_auto_regressive_input
from fairseq2.recipes import Model, TrainUnit
from fairseq2.recipes.config import get_config_section
from fairseq2.recipes.lm._preference_finetune._common import (
    POCriterionSection,
    POFinetuneMetricBag,
    _gather_lprobs,
)
from fairseq2.recipes.lm._preference_finetune._handler import POFinetuneUnitHandler
from fairseq2.utils.structured import structure
from fairseq2.utils.validation import validate


@final
class CpoFinetuneUnit(TrainUnit[PreferenceBatch]):
    """Represents the language model CPO-finetuning unit.

    Paper: https://arxiv.org/abs/2401.08417.
    """

    _model: Model
    _beta: float
    _nll_scale: float
    _metric_bag: CpoFinetuneMetricBag

    def __init__(
        self,
        model: Model,
        gangs: Gangs,
        beta: float = 1.0,
        nll_scale: float = 1.0,
    ) -> None:
        self._model = model
        self._beta = beta
        self._nll_scale = nll_scale

        self._metric_bag = CpoFinetuneMetricBag(gangs.dp)

    @override
    def __call__(self, batch: PreferenceBatch) -> tuple[Tensor, int]:
        chosen_batch = batch.chosen
        chosen_input_batch, chosen_target_batch = as_auto_regressive_input(
            chosen_batch
        )

        rejected_batch = batch.rejected
        rejected_input_batch, rejected_target_batch = as_auto_regressive_input(
            rejected_batch
        )

        # Run the policy model on both the preferred and dispreferred sequences.
        chosen_output = cast(
            SequenceModelOutput, self._model.module(chosen_input_batch)
        )
        rejected_output = cast(
            SequenceModelOutput, self._model.module(rejected_input_batch)
        )

        # Per-sequence log-probabilities of the target tokens.
        chosen_logps = _gather_lprobs(chosen_output, chosen_target_batch)
        rejected_logps = _gather_lprobs(rejected_output, rejected_target_batch)

        cpo_loss = self._compute_cpo_loss(chosen_logps, rejected_logps)

        # NLL regularizer, computed over the preferred sequences only.
        nll_loss = chosen_output.compute_loss(
            chosen_target_batch.seqs, loss_mask=chosen_target_batch.target_mask
        )

        self._metric_bag.update_cpo_loss(batch, cpo_loss)

        self._metric_bag.update_nll_loss(chosen_batch, nll_loss)

        self._metric_bag.update_sequence_lengths(batch)

        self._metric_bag.update_logps(batch, chosen_logps, rejected_logps)

        self._metric_bag.update_batch_metrics(chosen_batch)

        loss = (
            cpo_loss
            + self._nll_scale
            * nll_loss
            * chosen_target_batch.batch_size
            / chosen_target_batch.num_target_elements()
        )  # normalization applied locally per-rank

        return loss, chosen_target_batch.batch_size

    def _compute_cpo_loss(
        self,
        chosen_logps: Tensor,
        rejected_logps: Tensor,
    ) -> Tensor:
        # -log sigmoid(beta * (log p(y_w|x) - log p(y_l|x))) for each pair.
        cpo_loss = -torch.nn.functional.logsigmoid(
            self._beta * (chosen_logps - rejected_logps)
        )

        return cpo_loss.sum()

    @property
    @override
    def model(self) -> Model:
        return self._model

    @property
    @override
    def metric_bag(self) -> CpoFinetuneMetricBag:
        return self._metric_bag

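# Illustrative sketch (not part of the original module): a standalone version
# of the pairwise objective computed by ``CpoFinetuneUnit._compute_cpo_loss``
# above, handy for sanity-checking the math on toy inputs. The helper name
# ``_cpo_loss_reference`` is hypothetical.
def _cpo_loss_reference(
    chosen_logps: Tensor, rejected_logps: Tensor, beta: float = 1.0
) -> Tensor:
    """Sum of -log sigmoid(beta * (logp_chosen - logp_rejected)) over pairs."""
    return -torch.nn.functional.logsigmoid(
        beta * (chosen_logps - rejected_logps)
    ).sum()


# For example, with hypothetical log-probabilities for two preference pairs:
#
#     >>> _cpo_loss_reference(torch.tensor([-12.3, -8.7]), torch.tensor([-15.1, -9.2]))
#     tensor(0.5331)
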
class CpoFinetuneMetricBag(POFinetuneMetricBag):
    """Holds the metrics of a CPO preference finetuning task."""

    cpo_loss: Mean

    def __init__(self, gang: Gang) -> None:
        super().__init__(gang)

        self.register_metric("cpo_loss", Mean(device=gang.device), persistent=False)

    @torch.inference_mode()
    def update_cpo_loss(self, batch: PreferenceBatch, loss: Tensor) -> None:
        """Update the CPO loss metric.

        :param batch: The batch processed by the model.
        :param loss: The CPO loss of ``batch``.
        """
        self.cpo_loss.update(
            loss / batch.chosen.batch_size, weight=batch.chosen.batch_size
        )

CPO_FINETUNE_UNIT: Final = "cpo"


@dataclass(kw_only=True)
class CpoFinetuneConfig:
    beta: float = 1.0
    """The coefficient applied to the difference between the log-probabilities
    of the preferred and dispreferred sequences."""

    nll_scale: float = 1.0
    """The coefficient of the NLL loss added to the CPO loss."""

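# Illustrative sketch (an assumption about the recipe config layout, inferred
# from how ``CpoFinetuneUnitHandler.create`` below reads the ``criterion``
# section): in a YAML recipe, selecting CPO with custom coefficients would look
# roughly like:
#
#     criterion:
#       name: cpo
#       config:
#         beta: 0.5
#         nll_scale: 1.0
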
@final
class CpoFinetuneUnitHandler(POFinetuneUnitHandler):
    @override
    def create(
        self, model: Model, gangs: Gangs, recipe_config: object
    ) -> TrainUnit[PreferenceBatch]:
        criterion_section = get_config_section(
            recipe_config, "criterion", POCriterionSection
        )

        config = structure(criterion_section.config, CpoFinetuneConfig)

        validate(config)

        return CpoFinetuneUnit(model, gangs, config.beta, config.nll_scale)

    @property
    @override
    def name(self) -> str:
        return CPO_FINETUNE_UNIT

    @property
    @override
    def config_kls(self) -> type[object]:
        return CpoFinetuneConfig