Datasets

Overview

This tutorial demonstrates how to interact with pre-defined datasets in fairseq2. We use the gsm8k_sft dataset (GSM8K formatted for generic instruction finetuning) as an example.

Import all necessary modules

[1]:
from pathlib import Path
from fairseq2 import setup_fairseq2
from fairseq2.context import get_runtime_context
from fairseq2.datasets import Batching, LengthBatching, StaticBatching
from fairseq2.recipes.common import (
    load_dataset,
    load_text_tokenizer,
    setup_gangs,
)
from fairseq2.recipes.config import DatasetSection, GangSection, ModelSection
from fairseq2.recipes.lm import InstructionFinetuneDatasetSection
from fairseq2.datasets.instruction import (
    InstructionDataset,
    InstructionReadOptions,
)

Initialization

We first need to initialize fairseq2 with setup_fairseq2(). This loads the configuration and registers the assets, which lets us interact with pre-defined datasets and models.

Prerequisite: Follow the HuggingFace Datasets Tutorial to download the gsm8k data (formatted in the fairseq2 flavor) to a local path (e.g. /datasets/facebook/fairseq2-lm-gsm8k/); a download sketch follows.
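
If you just need the files locally, here is a minimal, hedged sketch using huggingface_hub; the repo id and local path mirror the example path above and are assumptions, so follow the tutorial linked above for the authoritative steps:

from huggingface_hub import snapshot_download

# Download the fairseq2-flavored GSM8K data to the local path used below.
snapshot_download(
    repo_id="facebook/fairseq2-lm-gsm8k",  # assumed dataset repo id
    repo_type="dataset",
    local_dir="/datasets/facebook/fairseq2-lm-gsm8k",
)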

An example datapoint from the SFT JSONL:

{
    "src": "<|start_header_id|>user<|end_header_id|>\n\nBrittany got a 78 on her first test. After her second test, her average rose to an 81. What grade did she get on her second test?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
    "tgt": "First multiply her average grade by the number of tests she took to find the total number of points she scored: 81 points * 2 = <<81*2=162>>162 points\nThen subtract the number of points she scored on her first exam to find how many points she scored on her second exam: 162 points - 78 points = <<162-78=84>>84 points\n#### 84"
}
[2]:
# Setup fairseq2
setup_fairseq2()

context = get_runtime_context()

# Configure the dataset section (asset name + local path)
dataset_config = InstructionFinetuneDatasetSection(
    name="gsm8k_sft", path=Path("/datasets/facebook/fairseq2-lm-gsm8k/sft")
)
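
The section also carries the reader defaults that we will consume later in this tutorial (the split name, sequence-length bounds, and batching knobs). A quick way to inspect them (a sketch; the field names are the ones used further down):

# Inspect the knobs that the read options and reader will use below.
print(dataset_config.train_split)
print(dataset_config.min_seq_len, dataset_config.max_seq_len)
print(dataset_config.batch_size, dataset_config.max_num_tokens)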

Prepare the assets

We will now build a small configuration object, set up the gangs, and load the dataset and tokenizer. Under the hood, these helpers use retrieve_asset_card to fetch the corresponding asset cards from the asset store.

[3]:
# prepare the seed (used for reproducible shuffling)
seed = 42


class Config:
    """
    A configuration object for the dataset and model.
    """

    def __init__(self, gang: GangSection, dataset: DatasetSection, model: ModelSection):
        self.gang = gang
        self.dataset = dataset
        self.model = model


config = Config(
    gang=GangSection(tensor_parallel_size=1),
    dataset=dataset_config,
    model=ModelSection(name="llama3_1_8b"),
)
gangs = setup_gangs(context, config)
dataset = load_dataset(InstructionDataset, context, config, gangs)
# load the tokenizer
tokenizer = load_text_tokenizer(context, config)
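
As a quick sanity check, the tokenizer can be exercised directly. A minimal sketch, assuming the create_encoder/create_decoder factories of TextTokenizer; the "prompt" mode mirrors the source encode mode used by the read options below:

# Encode a prompt the way the instruction dataset encodes sources.
encoder = tokenizer.create_encoder(mode="prompt")
decoder = tokenizer.create_decoder()

indices = encoder("What is 81 * 2?")
print(indices)
print(decoder(indices))
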
[4]:
# prepare the batching strategy
batching: Batching

if dataset_config.batch_size is not None:
    batching = StaticBatching(dataset_config.batch_size)
else:
    batching = LengthBatching(dataset_config.max_num_tokens)

# prepare the read options
read_options = InstructionReadOptions(
    batching=batching,
    example_shuffle_window=dataset_config.example_shuffle_window,
    batch_shuffle_window=dataset_config.batch_shuffle_window,
    num_prefetch=dataset_config.num_prefetch,
    source_encode_mode=dataset_config.source_encode_mode,
    target_encode_mode=dataset_config.target_encode_mode,
    seed=seed,
)
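
The two strategies behave differently: StaticBatching yields a fixed number of examples per batch, while LengthBatching packs variable-sized batches whose total token count stays below max_num_tokens. A one-line check of which one the config selected (a sketch; with batch_size unset this typically prints LengthBatching):

print(type(batching).__name__)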

Create Data Reader

To create a data reader, we pass in the gang and the options prepared above. If you dig into the create_reader method, you will see that it implements the data pipeline covered in notebooks/data/datapipeline.ipynb.

[5]:
data_reader = dataset.create_reader(
    dataset_config.train_split,
    tokenizer,
    gangs.dp,
    dataset_config.min_seq_len,
    dataset_config.max_seq_len,
    read_options,
)

Iterate over the batches

Now that we have the data reader, we can iterate over the batches.

[6]:
# A data reader yields a list of batches per step; StopIteration
# signals the end of the epoch.
try:
    batches = next(data_reader)
except StopIteration:
    batches = None

if batches is not None:
    for batch_nr, batch in enumerate(batches):
        print(f"===batch_nr==={batch_nr}===")
        print(batch)
        print("")
else:
    print("No more batches")
    data_reader.reset()  # rewind the reader for the next epoch
===batch_nr===0===
SequenceBatch(seqs=tensor([[128000, 128006,    882,  ...,    220,  10132, 128001],
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0],
        ...,
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0]],
       device='cuda:0'), padding_mask=<fairseq2.nn.padding.PaddingMask object at 0x78220946d270>, target_mask=tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]], device='cuda:0'), example={'id': [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None], 'indices': {'is_ragged': True, 'seqs': tensor([[128000, 128006,    882,  ...,    220,  10132, 128001],
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0],
        ...,
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,
        327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,
        319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,
        319, 321, 320, 321, 316, 311])}, 'target_mask': {'is_ragged': True, 'seqs': tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,
        327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,
        319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,
        319, 321, 320, 321, 316, 311])}})
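
To consume the whole split, you can also loop until the reader is exhausted. A minimal sketch, assuming the reader follows the iterator protocol (as the next() call above suggests) and that reset() rewinds it for another epoch; seqs.size(0) counts the rows of the padded batch:

num_batches = 0
num_examples = 0

for batches in data_reader:
    for batch in batches:
        num_batches += 1
        num_examples += batch.seqs.size(0)  # examples in this padded batch

print(f"{num_batches} batches, {num_examples} examples")

data_reader.reset()  # rewind before the next epoch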