Source code for fairseq2.recipes.lm._preference_finetune._common
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import Tensor
from torcheval.metrics import Mean

from fairseq2.datasets.preference import PreferenceBatch
from fairseq2.gang import Gang
from fairseq2.models.sequence import SequenceBatch, SequenceModelOutput
from fairseq2.recipes.metrics import SequenceMetricBag


def _gather_lprobs(output: SequenceModelOutput, target: SequenceBatch) -> Tensor:
    """Sum the log probabilities the model assigns to the target tokens.

    Positions where ``target.target_mask`` is zero/False contribute nothing
    to the sum.

    :param output: The model output whose ``logits`` are scored. ``logits``
        must have one more trailing (vocabulary) dimension than
        ``target.seqs``, as required by the ``torch.gather`` call below;
        presumably (batch, seq_len, vocab) -- TODO confirm with callers.
    :param target: The sequences to score. ``target_mask`` must be set.
    :returns: The masked sum of per-token log probabilities over the last
        dimension of ``target.seqs`` (one value per sequence when
        ``target.seqs`` is 2-D).
    """
    assert target.target_mask is not None  # also narrows the type for mypy

    lprobs = torch.log_softmax(output.logits, dim=-1)

    # At every position, select the log probability of the actual target token.
    token_lprobs = torch.gather(lprobs, -1, target.seqs.unsqueeze(-1)).squeeze(-1)

    # Multiplying by the mask zeroes out non-target positions before summing.
    return (token_lprobs * target.target_mask).sum(dim=-1)


def _gather_lprobs_avg(
    output: SequenceModelOutput, target: SequenceBatch
) -> tuple[Tensor, Tensor]:
    """Return both the summed and the per-token-average target log probabilities.

    :param output: The model output whose ``logits`` are scored (same
        requirements as for :func:`_gather_lprobs`).
    :param target: The sequences to score. ``target_mask`` must be set.
    :returns: A ``(total, average)`` pair. ``total`` is exactly what
        :func:`_gather_lprobs` returns. ``average`` divides it by the number
        of target positions in each sequence. A sequence with no target
        positions yields ``0 / 0``, i.e. NaN, in ``average``.
    """
    # Reuse the summing helper rather than duplicating the gather/mask logic.
    total_lprobs = _gather_lprobs(output, target)

    # _gather_lprobs already asserts this. It is repeated here only so type
    # checkers can narrow ``target_mask`` from an optional value.
    assert target.target_mask is not None

    average_lprobs = total_lprobs / target.target_mask.sum(-1)

    return total_lprobs, average_lprobs
@torch.inference_mode()
def update_logps(
    self,
    batch: PreferenceBatch,
    chosen_logps: Tensor,
    rejected_logps: Tensor,
) -> None:
    """Update the Chosen and Rejected Sequence Log Probabilities metrics.

    Each metric receives the mean log probability over the sequences of the
    corresponding half of the batch. That mean is weighted by the number of
    sequences in that half.

    :param batch: The batch processed by the model.
    :param chosen_logps: The log probabilities for each sequence in
        ``batch.chosen``.
    :param rejected_logps: The log probabilities for each sequence in
        ``batch.rejected``.
    """
    targets = (
        (self.chosen_logps, chosen_logps, batch.chosen.batch_size),
        (self.rejected_logps, rejected_logps, batch.rejected.batch_size),
    )

    for metric, logps, num_seqs in targets:
        batch_mean = logps.sum() / num_seqs
        metric.update(batch_mean, weight=num_seqs)
@torch.inference_mode()
def update_sequence_lengths(
    self,
    batch: PreferenceBatch,
) -> None:
    """Update the Chosen Sequence Length and Rejected Sequence Length metrics.

    Each metric receives the average number of target elements per sequence
    in the corresponding half of the batch, as a one-element tensor. That
    average is weighted by the number of sequences in that half.

    :param batch: The batch processed by the model.
    """
    for metric, seqs in (
        (self.chosen_lengths, batch.chosen),
        (self.rejected_lengths, batch.rejected),
    ):
        mean_length = seqs.num_target_elements() / seqs.batch_size

        # ``torch.tensor`` is the documented factory for building a tensor
        # from data. The legacy ``torch.Tensor(list)`` constructor is
        # discouraged. The one-element shape (1,) is kept for compatibility.
        metric.update(torch.tensor([mean_length]), weight=seqs.batch_size)