Preference Optimization¶
What you will learn
- How to run the preference optimization recipe
- How to customize the criterion units and use DPO/CPO/ORPO/SimPO
- How to run preference optimization with multiple nodes in fairseq2
Prerequisites
- Get familiar with fairseq2 basics (Overview)
- Ensure you have fairseq2 installed (Installation)
- Get familiar with the tutorial on end-to-end fine-tuning (End-to-End Fine-Tuning)
Overview¶
Prepare
- Download the LLaMA3.1 8B instruction-finetuned model from HuggingFace
- Download the gsm8k data prepared for this tutorial
Fine-Tune
- One simple command to run the preference optimization recipe
- Choose between different preference optimization methods (DPO/CPO/ORPO/SimPO)
- Customize configurations for each method
- Accelerate the training with multiple nodes
Generate
- One simple command to generate from the finetuned model
- Since this step is similar to what has been covered in End-to-End Fine-Tuning, we will not elaborate on it here.
Prepare¶
Model¶
Follow the HuggingFace Models Tutorial to download the LLaMA3.1 8B instruction-finetuned model, which can be run on volta32gb GPUs.
Once you have the model in your local path (e.g. /models/Llama-3.1-8B/original/consolidated.00.pth), you need to register the model in a YAML card so that fairseq2 knows where to pull the model from (read more about Assets). To do that:
Create a YAML file (e.g. my_llama3_1_8b_instruct.yaml) with the following content:
name: llama3_1_8b_instruct@user
checkpoint: "/models/Llama-3.1-8B/original/consolidated.00.pth"
tokenizer: "/models/Llama-3.1-8B/original/tokenizer.model"
Tip
The @user suffix marks this card as belonging to your personal (user) environment. Environments can also be extended to resolve different asset definitions for different clusters (e.g. by domain name).
Save the file in one of the following locations:
Option 1: Place it in the default fairseq2 asset directory
mkdir -p ~/.config/fairseq2/assets
mv my_llama3_1_8b_instruct.yaml ~/.config/fairseq2/assets/
Option 2: Specify a custom directory and point FAIRSEQ2_USER_ASSET_DIR to it
export FAIRSEQ2_USER_ASSET_DIR=/path/to/custom/asset/directory
mv my_llama3_1_8b_instruct.yaml /path/to/custom/asset/directory/
You can check out the predefined fairseq2 LLaMA model cards here.
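Optionally, you can sanity-check the card before launching a run. The snippet below is a minimal sketch that assumes PyYAML is installed and that you saved the card to the default location from Option 1; adjust the path if you used FAIRSEQ2_USER_ASSET_DIR instead.

```python
# A quick sanity check for the asset card registered above: load the YAML and
# verify that the checkpoint and tokenizer paths exist on disk.
from pathlib import Path

import yaml

card_path = Path("~/.config/fairseq2/assets/my_llama3_1_8b_instruct.yaml").expanduser()
card = yaml.safe_load(card_path.read_text())

for key in ("checkpoint", "tokenizer"):
    path = Path(card[key])
    print(f"{key}: {path} -> {'found' if path.exists() else 'MISSING'}")
```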
Dataset¶
Follow the HuggingFace Datasets Tutorial to download the gsm8k data (formatted with the fairseq2 flavor) to your local path (e.g. /datasets/facebook/fairseq2-lm-gsm8k/).
We will use dpo/train.jsonl to fine-tune the model and test/test.jsonl for evaluation.
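If you are curious what a preference example looks like, you can peek at the first record of the DPO split. This is a minimal sketch; adjust the path to wherever you downloaded the data.

```python
# Print the field names and content of the first preference example in the
# fairseq2-flavored gsm8k DPO split.
import json

with open("/datasets/facebook/fairseq2-lm-gsm8k/dpo/train.jsonl") as f:
    first_example = json.loads(f.readline())

print(list(first_example.keys()))  # which fields a preference example provides
print(first_example)
```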
Fine-Tune¶
One-Liner¶
Running the preference optimization recipe is as simple as:
fairseq2 lm preference_finetune $OUTPUT_DIR --config \
dataset.path=/datasets/facebook/fairseq2-lm-gsm8k/dpo \
model.name=llama3_1_8b_instruct \
trainer.dtype=float16 \
regime.num_steps=1000 \
regime.num_data_epochs=20 \
regime.checkpoint_every_n_steps=1000
By default, DPO (direct preference optimization) is applied (--config criterion.name=dpo). The use of other methods (CPO/ORPO/SimPO) is documented below.
The configuration fields are detailed on the Recipes page. The fields follow a nested structure, where each field is a key-value pair. In the example above, we made changes to the dataset, model, trainer, and regime sections.
Dumping Configuration¶
For a quick overview of all the sections and fields, you can use the --dump-config command:
fairseq2 lm preference_finetune --dump-config
Preference Optimization Methods¶
fairseq2 supports four different preference optimization methods:
DPO (Direct Preference Optimization)
- Key configuration parameters (see the loss sketch after this list):
  - beta: Coefficient of regularization towards the reference model (default: 0.1)
  - nll_scale: Coefficient of the NLL loss (default: 0.0)
  - length_normalization: Whether to use length-normalized rewards (default: False)
  - reference_model: Name of the reference model (default: llama3_1_8b_instruct)
  - reference_dtype: Data type of the reference model (default: float16)
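To make these parameters concrete, below is a minimal PyTorch sketch of how a DPO-style loss combines them. This is an illustration of the underlying math, not fairseq2's actual implementation; the function signature and tensor layout are assumptions.

```python
# A minimal sketch of a DPO-style loss (illustrative only, not fairseq2's code).
# The *_logps tensors are assumed to hold per-sequence log-probabilities summed
# over the target tokens; *_lens hold the number of target tokens per sequence.
import torch
import torch.nn.functional as F


def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    chosen_lens: torch.Tensor,
    rejected_lens: torch.Tensor,
    nll_loss: torch.Tensor,
    beta: float = 0.1,
    nll_scale: float = 0.0,
    length_normalization: bool = False,
) -> torch.Tensor:
    if length_normalization:
        # Use per-token (length-normalized) log-probabilities as rewards.
        policy_chosen_logps = policy_chosen_logps / chosen_lens
        policy_rejected_logps = policy_rejected_logps / rejected_lens
        ref_chosen_logps = ref_chosen_logps / chosen_lens
        ref_rejected_logps = ref_rejected_logps / rejected_lens

    # Implicit rewards are log-ratios against the frozen reference model; beta
    # controls how strongly the policy is regularized towards the reference.
    chosen_reward = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_reward = beta * (policy_rejected_logps - ref_rejected_logps)

    preference_loss = -F.logsigmoid(chosen_reward - rejected_reward).mean()

    # nll_scale mixes in a plain supervised (NLL) loss on the chosen responses.
    return preference_loss + nll_scale * nll_loss
```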
Example preset for DPO
Here’s an example preset config for DPO:
# dpo.yaml
model:
  _set_:
    name: llama3_1_8b_instruct
dataset:
  _set_:
    name: gsm8k_dpo
    path: null
    family: generic_preference
    source_encode_mode: prompt
    target_encode_mode: prompt_response
    mask_source_tokens: true
    min_seq_len: 1
    max_seq_len: 8192
    max_num_tokens: 16384
    batch_size: null
    example_shuffle_window: 10000
    batch_shuffle_window: 1000
    num_prefetch: 4
    extras: {}
criterion:
  _set_:
    name: dpo
    config:
      reference_model:
        name: llama3_1_8b_instruct
      reference_dtype: bfloat16
      beta: 0.1
      nll_scale: 0.0
      length_normalization: false
gang:
  _set_:
    tensor_parallel_size: 1
    timeout: 15
    high_priority: true
    monitored: false
trainer:
  _set_:
    dtype: bfloat16
    data_parallelism: fsdp
    mixed_precision: static
    gradient_accumulation: 1
    activation_checkpointing: true
    max_gradient_norm: null
    fp16_loss_scale:
    - 128.0
    - 0.0001
    torch_compile: false
    profile: null
    gradient_check: false
    anomaly_detection: false
  fsdp:
    _set_:
      version: v1
      granularity: layer
      hsdp: false
      reshard_after_forward: true
      fp32_reduce: true
optimizer:
  _set_:
    name: adamw
  config:
    _set_:
      lr: 5.5e-06
      betas:
      - 0.9
      - 0.95
      eps: 1.0e-08
      weight_decay: 0.1
      amsgrad: false
      maximize: false
      capturable: false
      differentiable: false
      impl: auto
      use_fp32: false
lr_scheduler:
  _set_:
    name: cosine_annealing
  config:
    _set_:
      cycle_len: null
      num_warmup_steps: 0
      cycle_mul: 1.0
      lr_mul: 1.0
      start_lr: 0.0
      final_lr: null
      final_lr_scale: 0.2
regime:
  _set_:
    num_steps: 5000
    num_data_epochs: null
    score_metric: null
    lower_score_better: false
    validate_after_n_steps: 0
    validate_every_n_steps: null
    validate_after_n_data_epochs: 0
    validate_every_n_data_epochs: null
    checkpoint_after_n_steps: 0
    checkpoint_every_n_steps: 1000
    checkpoint_after_n_data_epochs: 0
    checkpoint_every_n_data_epochs: null
    keep_last_n_checkpoints: 1
    keep_best_n_checkpoints: null
    keep_last_n_models: null
    keep_best_n_models: null
    publish_metrics_after_n_steps: 0
    publish_metrics_every_n_steps: 10
    publish_metrics_after_n_data_epochs: 0
    publish_metrics_every_n_data_epochs: null
common:
  _set_:
    seed: 2
  metric_recorders:
    log:
      _set_:
        enabled: true
    jsonl:
      _set_:
        enabled: true
    tensorboard:
      _set_:
        enabled: true
    wandb:
      _set_:
        enabled: false
        project: null
        run: null
  profilers:
    torch:
      _set_:
        enabled: false
        skip_n_steps: 4
        wait_n_steps: 0
        num_warmup_steps: 1
        num_active_steps: 4
        repeat: 1
  assets:
    _set_:
      extra_path: null
      checkpoint_dir: null
fairseq2 lm preference_finetune $OUTPUT_DIR --config-file /path/to/dpo.yaml
CPO (Contrastive Preference Optimization)
- Key configuration parameters (see the loss sketch after this list):
  - beta: Coefficient for preferred vs. dispreferred sequences (default: 1.0)
  - nll_scale: Coefficient of the NLL loss (default: 1.0)
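As with DPO above, here is a minimal PyTorch sketch of a CPO-style loss for intuition only (not fairseq2's implementation; the signature is an assumption). Note that CPO needs no reference model.

```python
# A minimal sketch of a CPO-style loss (illustrative only, not fairseq2's code).
# Unlike DPO, there is no reference model: the policy's log-probabilities of the
# chosen and rejected sequences are contrasted directly.
import torch
import torch.nn.functional as F


def cpo_loss(
    chosen_logps: torch.Tensor,    # per-sequence log p(chosen | prompt) under the policy
    rejected_logps: torch.Tensor,  # per-sequence log p(rejected | prompt) under the policy
    nll_loss: torch.Tensor,        # standard NLL loss on the chosen responses
    beta: float = 1.0,
    nll_scale: float = 1.0,
) -> torch.Tensor:
    preference_loss = -F.logsigmoid(beta * (chosen_logps - rejected_logps)).mean()
    return preference_loss + nll_scale * nll_loss
```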
Example preset for CPO
Here’s an example preset config for CPO:
# cpo.yaml
model:
  _set_:
    name: llama3_1_8b_instruct
dataset:
  _set_:
    path: /datasets/facebook/fairseq2-lm-gsm8k/dpo
    batch_size: 1
criterion:
  _set_:
    name: cpo
    config:
      beta: 0.1
      nll_scale: 0.0
Then, to run the preference finetuning recipe with the CPO unit:
fairseq2 lm preference_finetune $OUTPUT_DIR --config-file /path/to/cpo.yaml
ORPO (Odds Ratio Preference Optimization)
- Key configuration parameters (see the loss sketch after this list):
  - orpo_lambda: Coefficient of the odds-ratio component (default: 1.0)
  - nll_scale: Coefficient of the NLL loss (default: 1.0)
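A minimal PyTorch sketch of an ORPO-style loss, again for intuition only (not fairseq2's implementation; the signature and the use of length-normalized log-probabilities are assumptions):

```python
# A minimal sketch of an ORPO-style loss (illustrative only, not fairseq2's code).
# ORPO is also reference-free: it penalizes the odds of the rejected response
# relative to the chosen one and adds a supervised NLL term.
import torch
import torch.nn.functional as F


def orpo_loss(
    chosen_logps: torch.Tensor,    # length-normalized log p(chosen | prompt), i.e. <= 0
    rejected_logps: torch.Tensor,  # length-normalized log p(rejected | prompt)
    nll_loss: torch.Tensor,        # standard NLL loss on the chosen responses
    orpo_lambda: float = 1.0,
    nll_scale: float = 1.0,
) -> torch.Tensor:
    # log-odds: log(p / (1 - p)) = log p - log(1 - p), computed in a stable way.
    log_odds_chosen = chosen_logps - torch.log1p(-torch.exp(chosen_logps))
    log_odds_rejected = rejected_logps - torch.log1p(-torch.exp(rejected_logps))

    odds_ratio_loss = -F.logsigmoid(log_odds_chosen - log_odds_rejected).mean()
    return nll_scale * nll_loss + orpo_lambda * odds_ratio_loss
```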
Example preset for ORPO
Here’s an example preset config for ORPO:
# orpo.yaml
model:
  _set_:
    name: llama3_1_8b_instruct
dataset:
  _set_:
    path: /datasets/facebook/fairseq2-lm-gsm8k/dpo
    batch_size: 1
criterion:
  _set_:
    name: orpo
    config:
      nll_scale: 0.0
      orpo_lambda: 0.1
Then, to run the preference finetuning recipe with the ORPO unit:
fairseq2 lm preference_finetune $OUTPUT_DIR --config-file /path/to/orpo.yaml
SimPO (Simple Preference Optimization)
- Key configuration parameters (see the loss sketch after this list):
  - beta: Coefficient of KL-divergence regularization (default: 1.0)
  - gamma: Target reward margin between completions (default: 0.5)
  - nll_scale: Coefficient of the NLL loss (default: 0.0)
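A minimal PyTorch sketch of a SimPO-style loss for intuition only (not fairseq2's implementation; the signature is an assumption). SimPO is reference-free, length-normalizes the implicit rewards, and uses gamma as a target margin:

```python
# A minimal sketch of a SimPO-style loss (illustrative only, not fairseq2's code).
import torch
import torch.nn.functional as F


def simpo_loss(
    chosen_logps: torch.Tensor,    # per-sequence log p(chosen | prompt), summed over tokens
    rejected_logps: torch.Tensor,  # per-sequence log p(rejected | prompt), summed over tokens
    chosen_lens: torch.Tensor,     # number of target tokens per chosen sequence
    rejected_lens: torch.Tensor,   # number of target tokens per rejected sequence
    nll_loss: torch.Tensor,        # standard NLL loss on the chosen responses
    beta: float = 1.0,
    gamma: float = 0.5,
    nll_scale: float = 0.0,
) -> torch.Tensor:
    # Length-normalized implicit rewards, scaled by beta.
    chosen_reward = beta * chosen_logps / chosen_lens
    rejected_reward = beta * rejected_logps / rejected_lens

    # gamma is the target margin the chosen reward should exceed the rejected one by.
    preference_loss = -F.logsigmoid(chosen_reward - rejected_reward - gamma).mean()
    return preference_loss + nll_scale * nll_loss
```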
Example preset for SimPO
Here’s an example preset config for SimPO:
# simpo.yaml
model:
  _set_:
    name: llama3_1_8b_instruct
dataset:
  _set_:
    path: /datasets/facebook/fairseq2-lm-gsm8k/dpo
    batch_size: 1
criterion:
  _set_:
    name: simpo
    config:
      beta: 2
      nll_scale: 0.0
Then, to run the preference finetuning recipe with the SimPO unit:
fairseq2 lm preference_finetune $OUTPUT_DIR --config-file /path/to/simpo.yaml
Iterative Training¶
Sometimes you may want to continue fine-tuning from a previously trained checkpoint, either to:
- Resume interrupted training
- Fine-tune on additional data
- Perform iterative fine-tuning with different hyperparameters
fairseq2 provides a clean way to handle this through the checkpoint system (learn more about Checkpoint Management):
fairseq2 lm preference_finetune $OUTPUT_DIR --config \
    common.assets.checkpoint_dir=/path/to/checkpoint \
    model.name=last_checkpoint \
    dataset.path=/path/to/data
Here, model.name=last_checkpoint picks up the last checkpoint in the directory.
To pick up a specific checkpoint instead:
CKPT_DIR="/checkpoint/user/experiments/run_0/checkpoints"
CKPT="checkpoint_step_1000" # e.g. checkpoint of step 1000
fairseq2 lm preference_finetune $OUTPUT_DIR --config \
common.assets.checkpoint_dir=$CKPT_DIR \
model.name=$CKPT \
dataset.path=/path/to/new/data \
dataset.max_num_tokens=4096 \
trainer.dtype=float16
Note
If you want to pick a specific checkpoint instead of the last checkpoint, the model.name parameter must be set to checkpoint_step_X, where X matches the step number of the checkpoint you want to load.
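If you are unsure which steps were checkpointed, simply listing the checkpoint directory gives a quick overview. This is a minimal sketch; the exact sub-directory naming depends on your fairseq2 version.

```python
# List the contents of a checkpoint directory to see which training steps were
# saved. The path matches the example CKPT_DIR above; treat the output as
# orientation only.
from pathlib import Path

ckpt_dir = Path("/checkpoint/user/experiments/run_0/checkpoints")

for entry in sorted(ckpt_dir.iterdir()):
    print(entry.name)
```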
Multi-Node¶
To help accelerate the training, fairseq2 can automatically detect a multi-node setup.
Option 1: Slurm
srun --nodes=2 --ntasks-per-node=8 \
    fairseq2 lm preference_finetune $OUTPUT_DIR \
    ...
Option 2: Torchrun
torchrun --standalone --nproc-per-node 8 --no-python \
    fairseq2 lm preference_finetune $OUTPUT_DIR \
    ...
Generate¶
Once we have finished the training, we can find the model checkpoints in $OUTPUT_DIR/checkpoints.
With that, we can now generate over the test dataset!
You can either use the fairseq2 native generation recipe:
CKPT_DIR="/checkpoint/$USER/my_experiment/checkpoints"
CKPT="last_checkpoint"
SAVE_DIR="/checkpoint/$USER/my_experiment/generations"
DATASET="/datasets/facebook/fairseq2-lm-gsm8k/test/test.jsonl"
fairseq2 lm generate $SAVE_DIR --no-sweep-dir --config \
common.assets.checkpoint_dir=$CKPT_DIR \
model.name=$CKPT \
seq_generator.config.temperature=0.1 \
dataset.path=$DATASET
Or accelerate generation with vLLM:
from vllm import LLM

llm = LLM(
    model="<path_to_fs2_checkpoint>",            # path of your model
    tokenizer="<name_or_path_of_hf_tokenizer>",  # path of your tokenizer files
)
output = llm.generate("Hello, my name is")
print(output)
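For finer control over decoding (for example, to approximate the temperature=0.1 used by the native recipe above), you can pass vLLM's SamplingParams to generate. This is a minimal sketch with placeholder paths and an illustrative gsm8k-style prompt.

```python
# Continuing the vLLM example above with explicit sampling parameters.
# The model/tokenizer paths are placeholders you must fill in.
from vllm import LLM, SamplingParams

llm = LLM(
    model="<path_to_fs2_checkpoint>",
    tokenizer="<name_or_path_of_hf_tokenizer>",
)

sampling_params = SamplingParams(temperature=0.1, max_tokens=512)

outputs = llm.generate(
    ["Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"],
    sampling_params,
)

for request_output in outputs:
    # Each RequestOutput holds the prompt and its generated completions.
    print(request_output.outputs[0].text)
```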
To keep this tutorial concise, please refer to End-to-End Fine-Tuning for more details.