Datasets¶
Overview¶
This tutorial demonstrates how to interact with pre-defined datasets in fairseq2, using the gsm8k_sft
dataset (GSM8K formatted for generic instruction finetuning) as an example.
Import all necessary modules¶
[1]:
from pathlib import Path

from fairseq2 import setup_fairseq2
from fairseq2.context import get_runtime_context
from fairseq2.datasets import Batching, LengthBatching, StaticBatching
from fairseq2.datasets.instruction import (
    InstructionDataset,
    InstructionReadOptions,
)
from fairseq2.recipes.common import (
    load_dataset,
    load_text_tokenizer,
    setup_gangs,
)
from fairseq2.recipes.config import DatasetSection, GangSection, ModelSection
from fairseq2.recipes.lm import InstructionFinetuneDatasetSection
Initialization¶
We first need to initialize fairseq2 by calling setup_fairseq2(). This loads the configuration and registers the assets, which lets us interact with pre-defined datasets and models.
Prerequisite: Follow the HuggingFace Datasets Tutorial to download the gsm8k data (formatted in the fairseq2 flavor) to a local path (e.g. /datasets/facebook/fairseq2-lm-gsm8k/).
One example datapoint from the SFT JSONL:
{
    "src": "<|start_header_id|>user<|end_header_id|>\n\nBrittany got a 78 on her first test. After her second test, her average rose to an 81. What grade did she get on her second test?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
    "tgt": "First multiply her average grade by the number of tests she took to find the total number of points she scored: 81 points * 2 = <<81*2=162>>162 points\nThen subtract the number of points she scored on her first exam to find how many points she scored on her second exam: 162 points - 78 points = <<162-78=84>>84 points\n#### 84"
}
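If you want to sanity-check the raw data before loading it through fairseq2, you can peek at the first record with plain Python. This is a minimal sketch; the directory layout and the *.jsonl glob are assumptions based on the example path above:

import json
from pathlib import Path

# Peek at the first record of the downloaded SFT split. The path and the
# *.jsonl glob are assumptions; adjust them to your local layout.
data_dir = Path("/datasets/facebook/fairseq2-lm-gsm8k/sft")
first_file = next(data_dir.glob("*.jsonl"))
with first_file.open() as fp:
    record = json.loads(fp.readline())
print(record["src"][:80])  # the prompt, rendered in the chat template
print(record["tgt"][:80])  # the target completion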
[2]:
# Setup fairseq2
setup_fairseq2()
context = get_runtime_context()
# Load the configuration
dataset_config = InstructionFinetuneDatasetSection(
    name="gsm8k_sft", path=Path("/datasets/facebook/fairseq2-lm-gsm8k/sft")
)
Prepare the assets¶
We will load the dataset and the tokenizer (the latter resolved from the model card). Under the hood, the retrieve_asset_card
function retrieves each asset card from the asset store.
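As a quick illustration of what happens under the hood, you can fetch an asset card by name yourself. This is a sketch assuming the asset store is exposed on the runtime context, as in recent fairseq2 versions:

# Sketch: retrieve the dataset's asset card by name and inspect it.
# `context.asset_store.retrieve_card` is an assumption about the runtime
# context API; check your fairseq2 version.
card = context.asset_store.retrieve_card("gsm8k_sft")
print(card.name)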
[3]:
# prepare the seed
seed = 42


class Config:
    """A configuration object for the dataset and model."""

    def __init__(self, gang: GangSection, dataset: DatasetSection, model: ModelSection):
        self.gang = gang
        self.dataset = dataset
        self.model = model


config = Config(
    gang=GangSection(tensor_parallel_size=1),
    dataset=dataset_config,
    model=ModelSection(name="llama3_1_8b"),
)
# set up the gangs (process groups) for distributed processing
gangs = setup_gangs(context, config)

# load the dataset
dataset = load_dataset(InstructionDataset, context, config, gangs)

# load the tokenizer
tokenizer = load_text_tokenizer(context, config)
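Before reading any data, it is worth confirming the tokenizer works by round-tripping a string through it. create_encoder and create_decoder are part of fairseq2's TextTokenizer interface; the exact decoder output type may vary across versions:

# Round-trip a prompt through the tokenizer as a sanity check.
encoder = tokenizer.create_encoder()
decoder = tokenizer.create_decoder()

indices = encoder("What grade did she get on her second test?")
print(indices)           # a 1-D tensor of token indices
print(decoder(indices))  # should recover (roughly) the original text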
[4]:
# prepare the batching strategy
batching: Batching

if dataset_config.batch_size is not None:
    batching = StaticBatching(dataset_config.batch_size)
else:
    batching = LengthBatching(dataset_config.max_num_tokens)

# prepare the read options
read_options = InstructionReadOptions(
    batching=batching,
    example_shuffle_window=dataset_config.example_shuffle_window,
    batch_shuffle_window=dataset_config.batch_shuffle_window,
    num_prefetch=dataset_config.num_prefetch,
    source_encode_mode=dataset_config.source_encode_mode,
    target_encode_mode=dataset_config.target_encode_mode,
    seed=seed,
)
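The two strategies trade off differently: StaticBatching yields a fixed number of examples per batch regardless of their lengths, while LengthBatching packs examples until a token budget is reached, which reduces padding waste on variable-length data. Illustratively (the numbers here are arbitrary, not recommendations):

# Illustrative only: constructing each strategy directly.
fixed_batching = StaticBatching(16)      # always 16 examples per batch
budget_batching = LengthBatching(8192)   # pack each batch up to ~8192 tokens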
Create Data Reader¶
To create a data reader, we pass the tokenizer, the data-parallel gang, the sequence length bounds, and the read options we prepared above. If you dig into the create_reader
method, you will see that it implements the data pipeline covered in notebooks/data/datapipeline.ipynb
.
[5]:
data_reader = dataset.create_reader(
    dataset_config.train_split,
    tokenizer,
    gangs.dp,
    dataset_config.min_seq_len,
    dataset_config.max_seq_len,
    read_options,
)
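Data readers are also stateful, which is how fairseq2 recipes can resume mid-epoch from a checkpoint. A minimal sketch, assuming the state_dict/load_state_dict pair that fairseq2's stateful objects expose:

# Sketch: capture and later restore the reader position. The Stateful
# protocol on DataReader is an assumption; check your fairseq2 version.
reader_state = data_reader.state_dict()
# ... train for a while, checkpoint, restart ...
data_reader.load_state_dict(reader_state)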
Iterate over the batches¶
Now that we have the data reader, we can iterate over the batches.
[6]:
try:
    batches = next(data_reader)
except StopIteration:
    batches = None

if batches is not None:
    for batch_nr, batch in enumerate(batches):
        print(f"===batch_nr==={batch_nr}===")
        print(batch)
        print("")
else:
    print("No more batches")

# rewind the reader so the split can be iterated again
data_reader.reset()
===batch_nr===0===
SequenceBatch(seqs=tensor([[128000, 128006, 882, ..., 220, 10132, 128001],
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0],
...,
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0]],
device='cuda:0'), padding_mask=<fairseq2.nn.padding.PaddingMask object at 0x78220946d270>, target_mask=tensor([[False, False, False, ..., True, True, True],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]], device='cuda:0'), example={'id': [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None], 'indices': {'is_ragged': True, 'seqs': tensor([[128000, 128006, 882, ..., 220, 10132, 128001],
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0],
...,
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,
327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,
319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,
319, 321, 320, 321, 316, 311])}, 'target_mask': {'is_ragged': True, 'seqs': tensor([[False, False, False, ..., True, True, True],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,
327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,
319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,
319, 321, 320, 321, 316, 311])}})
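Each next(data_reader) call returns a list of batches, i.e. one training step's worth of data for this data-parallel rank. To drain a full epoch you can simply iterate the reader; a minimal sketch, using the seqs tensor shown in the output above to count examples:

# Sketch: one full pass over the split, counting examples.
num_examples = 0
for batches in data_reader:  # each item is a list of SequenceBatch objects
    for batch in batches:
        num_examples += batch.seqs.size(0)  # number of rows in this batch
print(f"examples seen: {num_examples}")

# rewind the reader so the split can be read again
data_reader.reset()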