Datasets

Overview

This tutorial demonstrates how to interact with pre-defined datasets in fairseq2. We use the gsm8k_sft dataset, which is loaded through the generic instruction-finetuning dataset family (InstructionDataset), as an example.

Import all necessary modules

[1]:
from pathlib import Path

from fairseq2 import setup_fairseq2
from fairseq2.context import get_runtime_context
from fairseq2.datasets import Batching, LengthBatching, StaticBatching
from fairseq2.datasets.instruction import (
    InstructionDataset,
    InstructionReadOptions,
)
from fairseq2.recipes.common import (
    load_dataset,
    load_text_tokenizer,
    setup_gangs,
)
from fairseq2.recipes.config import GangSection
from fairseq2.recipes.lm import InstructionFinetuneDatasetSection

Initialization

We first need to initialize fairseq2 by calling setup_fairseq2(). This loads the configuration and registers the assets, which allows us to interact with pre-defined datasets and models.

[2]:
# Setup fairseq2
setup_fairseq2()

context = get_runtime_context()

# Load the configuration
dataset_config = InstructionFinetuneDatasetSection()

dataset_config.name = "gsm8k_sft"
dataset_config.path = Path("/path/to/gsm8k_data/sft")
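
The dataset section carries several more fields that the next cells rely on (batching, shuffling, prefetching, and sequence-length bounds). As a quick sanity check, we can print the defaults; the field names below are exactly the ones read in the rest of this tutorial.

[ ]:
# Peek at the defaults used by the cells below (purely informational).
print(dataset_config.train_split)
print(dataset_config.min_seq_len, dataset_config.max_seq_len)
print(dataset_config.batch_size, dataset_config.max_num_tokens)
print(dataset_config.example_shuffle_window, dataset_config.batch_shuffle_window)
print(dataset_config.num_prefetch)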

Prepare the assets

We will set up the gangs (the process groups used for distributed training) and then load the dataset. Under the hood, load_dataset resolves the dataset card registered under our configured name in the asset store and uses it to instantiate the dataset.

[ ]:
# prepare the seed
seed = 42

# prepare the gangs (single process here, so no tensor parallelism)
gangs = setup_gangs(context, GangSection(tensor_parallel_size=1))
dataset = load_dataset(InstructionDataset, context, dataset_config, gangs)
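
This tutorial runs in a single process, so every gang ends up with a size of 1. If you want to confirm the topology, the data-parallel gang that we later pass to create_reader exposes its rank and size (a quick check, assuming the usual Gang interface):

[ ]:
# Sanity-check the gang topology; gangs.dp is the data-parallel gang
# that create_reader consumes below.
print(f"dp rank: {gangs.dp.rank}, dp size: {gangs.dp.size}")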
[ ]:
# prepare the batching strategy
batching: Batching

if dataset_config.batch_size is not None:
    batching = StaticBatching(dataset_config.batch_size)
else:
    batching = LengthBatching(dataset_config.max_num_tokens)

# prepare the read options
read_options = InstructionReadOptions(
    batching=batching,
    example_shuffle_window=dataset_config.example_shuffle_window,
    batch_shuffle_window=dataset_config.batch_shuffle_window,
    num_prefetch=dataset_config.num_prefetch,
    source_encode_mode=dataset_config.source_encode_mode,
    target_encode_mode=dataset_config.target_encode_mode,
    seed=seed,
)
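
As an aside, the two batching strategies behave quite differently: StaticBatching yields a fixed number of examples per batch, while LengthBatching packs a variable number of similarly sized examples under a token budget, which keeps padding to a minimum. A minimal sketch with made-up numbers:

[ ]:
# Illustrative only -- the numbers here are hypothetical.
fixed_batching = StaticBatching(16)      # always 16 examples per batch
budget_batching = LengthBatching(8192)   # pack up to ~8192 tokens per batch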
[ ]:
# load the tokenizer
tokenizer = load_text_tokenizer(context, "llama3_1_8b")
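
To verify that the tokenizer behaves as expected, we can round-trip a string through it. This is a minimal sketch assuming the standard create_encoder / create_decoder methods of the text tokenizer:

[ ]:
# Round-trip a sample prompt (sketch; assumes the standard
# create_encoder / create_decoder tokenizer interface).
encoder = tokenizer.create_encoder()
decoder = tokenizer.create_decoder()

indices = encoder("Natalia sold clips to 48 of her friends.")
print(indices)           # tensor of token indices
print(decoder(indices))  # decoded back to text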

Create Data Reader

With the gang, tokenizer, and read options in place, we can now create the data reader. If you dig into the create_reader method, you will see that it implements the data pipeline covered in notebooks/data/datapipeline.ipynb.

[ ]:
data_reader = dataset.create_reader(
    dataset_config.train_split,
    tokenizer,
    gangs.dp,
    dataset_config.min_seq_len,
    dataset_config.max_seq_len,
    read_options,
)

Iterate over the batches

Now that we have the data reader, we can iterate over the batches.

[8]:
try:
    # next() returns a list of batches (e.g. one per gradient-accumulation
    # step); StopIteration signals that the split is exhausted.
    batches = next(data_reader)
except StopIteration:
    batches = None

if batches is not None:
    for batch_nr, batch in enumerate(batches):
        print(f"===batch_nr==={batch_nr}===")
        print(batch)
        print("")
else:
    print("No more batches")
    data_reader.reset()
===batch_nr===0===
SequenceBatch(seqs=tensor([[128000, 128006,    882,  ...,    220,  10132, 128009],
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0],
        ...,
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0]]), padding_mask=<fairseq2.nn.padding.PaddingMask object at 0x7f6f630faf20>, target_mask=tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]]), example={'id': [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None], 'indices': {'is_ragged': True, 'seqs': tensor([[128000, 128006,    882,  ...,    220,  10132, 128009],
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0],
        ...,
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,
        327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,
        319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,
        319, 321, 320, 321, 316, 311])}, 'target_mask': {'is_ragged': True, 'seqs': tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,
        327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,
        319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,
        319, 321, 320, 321, 316, 311])}})
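
Since the data reader is a plain Python iterator, a full pass over the split is just a for-loop. Here is a hedged sketch of how a training loop would consume it (the actual forward/backward step is elided):

[ ]:
# Sketch of a full epoch: the reader yields lists of batches until the
# split is exhausted; reset() rewinds it for the next pass.
num_batches = 0

for batches in data_reader:
    for batch in batches:
        num_batches += 1  # a real recipe would run forward/backward here

print(f"consumed {num_batches} batches")

data_reader.reset()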