Datasets¶
Overview¶
This tutorial demonstrates how to interact with pre-defined datasets in fairseq2. We use the gsm8k_sft
dataset (a GSM8K split prepared for supervised finetuning, loaded through the generic instruction-finetuning interface) as an example.
Import all necessary modules¶
[1]:
from pathlib import Path

from fairseq2 import setup_fairseq2
from fairseq2.context import get_runtime_context
from fairseq2.datasets import Batching, LengthBatching, StaticBatching
from fairseq2.datasets.instruction import (
    InstructionDataset,
    InstructionReadOptions,
)
from fairseq2.recipes.common import (
    load_dataset,
    load_text_tokenizer,
    setup_gangs,
)
from fairseq2.recipes.config import GangSection
from fairseq2.recipes.lm import InstructionFinetuneDatasetSection
Initialization¶
We first need to initialize fairseq2 by calling setup_fairseq2(). This loads the configuration and registers the assets, which allows us to interact with pre-defined datasets and models.
[2]:
# Setup fairseq2
setup_fairseq2()
context = get_runtime_context()
# Configure the dataset section
dataset_config = InstructionFinetuneDatasetSection()
dataset_config.name = "gsm8k_sft"
dataset_config.path = Path("/path/to/gsm8k_data/sft")
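The path should point at the prepared GSM8K SFT data on disk. As a rough sketch of what the reader expects, the generic instruction format stores one JSON object per line; the "src"/"tgt" field names and the file name below follow the fairseq2 GSM8K preparation convention, but treat them as assumptions and check the docs for your version.
[ ]:
import json

# Hypothetical sample records; "src" holds the prompt and "tgt" the target
# completion. Field names and file naming are assumptions; verify them
# against your fairseq2 version.
records = [
    {"src": "What is 2 + 2?", "tgt": "2 + 2 = 4. The answer is 4."},
    {"src": "What is 3 * 5?", "tgt": "3 * 5 = 15. The answer is 15."},
]

dataset_config.path.mkdir(parents=True, exist_ok=True)

with open(dataset_config.path / "train.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")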
Prepare the assets¶
Next, we set up the gangs and load the dataset. Internally, load_dataset resolves the dataset's asset card from the asset store (this is what the retrieve_asset_card helper does in the recipe code) and then opens the dataset the card points to.
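If a card named gsm8k_sft is registered in your asset store (e.g. via a YAML card file), you can also inspect it yourself. The context.asset_store.retrieve_card accessor below is an assumption based on recent fairseq2 releases, and the call raises if no such card is registered.
[ ]:
# Assumed API: RuntimeContext.asset_store and AssetStore.retrieve_card();
# this raises if no "gsm8k_sft" card is registered in your environment.
card = context.asset_store.retrieve_card("gsm8k_sft")
print(card.name)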
[ ]:
# prepare the seed
seed = 42

# prepare the gangs; tensor_parallel_size must divide the world size,
# so we use 1 for this single-process tutorial
gangs = setup_gangs(context, GangSection(tensor_parallel_size=1))

dataset = load_dataset(InstructionDataset, context, dataset_config, gangs)
[ ]:
# prepare the batching strategy: fixed batch size if one is set,
# otherwise batch by total token count
batching: Batching

if dataset_config.batch_size is not None:
    batching = StaticBatching(dataset_config.batch_size)
else:
    batching = LengthBatching(dataset_config.max_num_tokens)

# prepare the read options
read_options = InstructionReadOptions(
    batching=batching,
    example_shuffle_window=dataset_config.example_shuffle_window,
    batch_shuffle_window=dataset_config.batch_shuffle_window,
    num_prefetch=dataset_config.num_prefetch,
    source_encode_mode=dataset_config.source_encode_mode,
    target_encode_mode=dataset_config.target_encode_mode,
    seed=seed,
)
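To build intuition for the two strategies: StaticBatching groups a fixed number of examples per batch, while LengthBatching packs examples into a batch until a token budget is reached. Below is a minimal, self-contained sketch of the length-batching idea; it is an illustration only, not fairseq2's actual implementation.
[ ]:
# Illustration only: greedy length batching over a list of sequence lengths.
def length_batches(seq_lens: list[int], max_num_tokens: int) -> list[list[int]]:
    batches, current, total = [], [], 0
    for n in seq_lens:
        if current and total + n > max_num_tokens:
            batches.append(current)
            current, total = [], 0
        current.append(n)
        total += n
    if current:
        batches.append(current)
    return batches

print(length_batches([338, 332, 329, 330, 333], max_num_tokens=1000))
# [[338, 332, 329], [330, 333]]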
[ ]:
# load the tokenizer
tokenizer = load_text_tokenizer(context, "llama3_1_8b")
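As a quick sanity check, you can round-trip a string through the tokenizer. The create_encoder()/create_decoder() accessors below follow fairseq2's TextTokenizer interface in recent releases; treat them as assumptions if your version differs.
[ ]:
# Assumed API: create_encoder() returns a callable mapping a string to a
# tensor of token ids, and create_decoder() maps ids back to text.
encoder = tokenizer.create_encoder()
decoder = tokenizer.create_decoder()

token_ids = encoder("Natalia sold 48 clips.")
print(token_ids)
print(decoder(token_ids))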
Create Data Reader¶
To create a data reader, we pass in the tokenizer, the data-parallel gang, and the read options we prepared above. If you dig into the create_reader
method, you will see that it implements the data pipeline covered in notebooks/data/datapipeline.ipynb.
[ ]:
data_reader = dataset.create_reader(
    dataset_config.train_split,
    tokenizer,
    gangs.dp,
    dataset_config.min_seq_len,
    dataset_config.max_seq_len,
    read_options,
)
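The same call works for other splits. For example, a reader over a hypothetical "test" split (the split name depends on how your data directory is laid out) can be created the same way:
[ ]:
# Hypothetical: a second reader over a "test" split, reusing the same
# tokenizer, gang, and read options; the split name is an assumption that
# depends on your data layout.
eval_reader = dataset.create_reader(
    "test",
    tokenizer,
    gangs.dp,
    dataset_config.min_seq_len,
    dataset_config.max_seq_len,
    read_options,
)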
Iterate over the batches¶
Now that we have the data reader, we can iterate over the batches. Note that each call to next() returns a list of batches rather than a single batch.
[8]:
try:
    batches = next(data_reader)
except StopIteration:
    batches = None

if batches is not None:
    for batch_nr, batch in enumerate(batches):
        print(f"===batch_nr==={batch_nr}===")
        print(batch)
        print("")
else:
    print("No more batches")

data_reader.reset()
===batch_nr===0===
SequenceBatch(seqs=tensor([[128000, 128006, 882, ..., 220, 10132, 128009],
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0],
...,
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0]]), padding_mask=<fairseq2.nn.padding.PaddingMask object at 0x7f6f630faf20>, target_mask=tensor([[False, False, False, ..., True, True, True],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]), example={'id': [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None], 'indices': {'is_ragged': True, 'seqs': tensor([[128000, 128006, 882, ..., 220, 10132, 128009],
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0],
...,
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0],
[128000, 128006, 882, ..., 0, 0, 0]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,
327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,
319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,
319, 321, 320, 321, 316, 311])}, 'target_mask': {'is_ragged': True, 'seqs': tensor([[False, False, False, ..., True, True, True],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,
327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,
319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,
319, 321, 320, 321, 316, 311])}})
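Each step of the reader yields a list of batches (useful, for example, for gradient accumulation), so a full pass over the data is a nested loop. The sketch below is a skeleton only; compute_loss is a hypothetical placeholder for your model's forward/backward step.
[ ]:
# Sketch of a full pass over the data; `compute_loss` is a hypothetical
# placeholder and not part of fairseq2.
data_reader.reset()

for step_nr, batches in enumerate(data_reader, start=1):
    for batch in batches:
        pass  # compute_loss(model, batch)
    if step_nr == 3:  # stop early for demonstration
        break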