{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ✎ Datasets\n",
"\n",
"## Overview\n",
"\n",
"This tutorial demonstrates how to interact with pre-defined datasets in fairseq2.\n",
"We use the `gsm8k_sft` (generic instruction finetuning) dataset as an example."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import all necessary modules"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from fairseq2 import setup_fairseq2\n",
"from fairseq2.context import get_runtime_context\n",
"from fairseq2.datasets import Batching, LengthBatching, StaticBatching\n",
"from fairseq2.recipes.common import (\n",
" load_dataset,\n",
" load_text_tokenizer,\n",
" setup_gangs,\n",
")\n",
"from fairseq2.recipes.config import DatasetSection, GangSection, ModelSection\n",
"from fairseq2.recipes.lm import InstructionFinetuneDatasetSection\n",
"from fairseq2.datasets.instruction import (\n",
" InstructionDataset,\n",
" InstructionReadOptions,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialization\n",
"\n",
"We first need to initialize fairseq2 by calling `setup_fairseq2()`.\n",
"This loads the configuration and registers the assets, which allows us to interact with pre-defined datasets and models.\n",
"\n",
"> Prerequisite: Follow the [HuggingFace Datasets Tutorial](https://huggingface.co/docs/hub/en/datasets-downloading) to download the [gsm8k data](https://huggingface.co/datasets/facebook/fairseq2-lm-gsm8k) (formatted with fairseq2 flavor) to your local path (_e.g._ `/datasets/facebook/fairseq2-lm-gsm8k/`).\n",
"\n",
"\n",
"Here is an example datapoint from the SFT jsonl file:\n",
"\n",
"```json\n",
"{\n",
" \"src\": \"<|start_header_id|>user<|end_header_id|>\\n\\nBrittany got a 78 on her first test. After her second test, her average rose to an 81. What grade did she get on her second test?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n",
" \"tgt\": \"First multiply her average grade by the number of tests she took to find the total number of points she scored: 81 points * 2 = <<81*2=162>>162 points\\nThen subtract the number of points she scored on her first exam to find how many points she scored on her second exam: 162 points - 78 points = <<162-78=84>>84 points\\n#### 84\"\n",
"}\n",
"```\n",
"\n",
" "
]
},
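{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you have not downloaded the data yet, the next cell is a minimal sketch that mirrors the tutorial linked above, using `huggingface_hub.snapshot_download`. It assumes the `huggingface_hub` package is installed; adjust `local_dir` to wherever you keep your datasets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A sketch of downloading the fairseq2-flavored gsm8k data with huggingface_hub.\n",
"# Assumes `huggingface_hub` is installed; skip this cell if the data is already\n",
"# present under /datasets/facebook/fairseq2-lm-gsm8k/.\n",
"from huggingface_hub import snapshot_download\n",
"\n",
"snapshot_download(\n",
"    repo_id=\"facebook/fairseq2-lm-gsm8k\",\n",
"    repo_type=\"dataset\",\n",
"    local_dir=\"/datasets/facebook/fairseq2-lm-gsm8k\",\n",
")"
]
},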
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Setup fairseq2\n",
"setup_fairseq2()\n",
"\n",
"context = get_runtime_context()\n",
"\n",
"# Load the configuration\n",
"dataset_config = InstructionFinetuneDatasetSection(\n",
" name=\"gsm8k_sft\", path=Path(\"/datasets/facebook/fairseq2-lm-gsm8k/sft\")\n",
")"
]
},
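{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, we can print a few fields of the dataset section we just created. The same fields drive the batching strategy and read options built later in this notebook; their exact default values depend on your fairseq2 version."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect a few fields of the dataset section; these are used below to build\n",
"# the batching strategy and the read options.\n",
"print(\"name:\", dataset_config.name)\n",
"print(\"path:\", dataset_config.path)\n",
"print(\"batch_size:\", dataset_config.batch_size)\n",
"print(\"max_num_tokens:\", dataset_config.max_num_tokens)\n",
"print(\"min_seq_len / max_seq_len:\", dataset_config.min_seq_len, dataset_config.max_seq_len)"
]
},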
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare the assets\n",
"\n",
"We will load both the dataset and the model card. The `retrieve_asset_card` function is used to load the asset card from the asset store."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# prepare the seed\n",
"seed = 42\n",
"\n",
"\n",
"class Config(object):\n",
" \"\"\"\n",
" A configuration object for the dataset and model.\n",
" \"\"\"\n",
"\n",
" def __init__(self, gang: GangSection, dataset: DatasetSection, model: ModelSection):\n",
" self.gang = gang\n",
" self.dataset = dataset\n",
" self.model = model\n",
"\n",
"\n",
"config = Config(\n",
" gang=GangSection(tensor_parallel_size=1),\n",
" dataset=dataset_config,\n",
" model=ModelSection(name=\"llama3_1_8b\"),\n",
")\n",
"gangs = setup_gangs(context, config)\n",
"dataset = load_dataset(InstructionDataset, context, config, gangs)\n",
"# load the tokenizer\n",
"tokenizer = load_text_tokenizer(context, config)"
]
},
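{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building batches, it can help to see what the tokenizer produces for a raw piece of text. The next cell is a small sketch that assumes the usual fairseq2 `TextTokenizer` interface, where `create_encoder()` returns a callable that maps a string to a tensor of token indices; the sample prompt is made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A sketch: encode a short sample prompt with the loaded tokenizer.\n",
"# Assumes the standard fairseq2 TextTokenizer interface (create_encoder()).\n",
"text_encoder = tokenizer.create_encoder()\n",
"\n",
"sample_prompt = \"What grade did Brittany get on her second test?\"\n",
"\n",
"token_indices = text_encoder(sample_prompt)\n",
"print(token_indices)"
]
},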
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# prepare the batching strategy\n",
"batching: Batching\n",
"\n",
"if dataset_config.batch_size is not None:\n",
" batching = StaticBatching(dataset_config.batch_size)\n",
"else:\n",
" batching = LengthBatching(dataset_config.max_num_tokens)\n",
"\n",
"# prepare the read options\n",
"read_options = InstructionReadOptions(\n",
" batching=batching,\n",
" example_shuffle_window=dataset_config.example_shuffle_window,\n",
" batch_shuffle_window=dataset_config.batch_shuffle_window,\n",
" num_prefetch=dataset_config.num_prefetch,\n",
" source_encode_mode=dataset_config.source_encode_mode,\n",
" target_encode_mode=dataset_config.target_encode_mode,\n",
" seed=seed,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create Data Reader\n",
"\n",
"To create a data reader, we need to prepare the gang and the batching options as well.\n",
"If you dig into the `create_reader` method, you will see that it implements the data pipeline that is covered in `notebooks/data/datapipeline.ipynb`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"data_reader = dataset.create_reader(\n",
" dataset_config.train_split,\n",
" tokenizer,\n",
" gangs.dp,\n",
" dataset_config.min_seq_len,\n",
" dataset_config.max_seq_len,\n",
" read_options,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Iterate over the batches\n",
"\n",
"Now that we have the data reader, we can iterate over the batches."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"===batch_nr===0===\n",
"SequenceBatch(seqs=tensor([[128000, 128006, 882, ..., 220, 10132, 128001],\n",
" [128000, 128006, 882, ..., 0, 0, 0],\n",
" [128000, 128006, 882, ..., 0, 0, 0],\n",
" ...,\n",
" [128000, 128006, 882, ..., 0, 0, 0],\n",
" [128000, 128006, 882, ..., 0, 0, 0],\n",
" [128000, 128006, 882, ..., 0, 0, 0]],\n",
" device='cuda:0'), padding_mask=, target_mask=tensor([[False, False, False, ..., True, True, True],\n",
" [False, False, False, ..., False, False, False],\n",
" [False, False, False, ..., False, False, False],\n",
" ...,\n",
" [False, False, False, ..., False, False, False],\n",
" [False, False, False, ..., False, False, False],\n",
" [False, False, False, ..., False, False, False]], device='cuda:0'), example={'id': [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None], 'indices': {'is_ragged': True, 'seqs': tensor([[128000, 128006, 882, ..., 220, 10132, 128001],\n",
" [128000, 128006, 882, ..., 0, 0, 0],\n",
" [128000, 128006, 882, ..., 0, 0, 0],\n",
" ...,\n",
" [128000, 128006, 882, ..., 0, 0, 0],\n",
" [128000, 128006, 882, ..., 0, 0, 0],\n",
" [128000, 128006, 882, ..., 0, 0, 0]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,\n",
" 327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,\n",
" 319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,\n",
" 319, 321, 320, 321, 316, 311])}, 'target_mask': {'is_ragged': True, 'seqs': tensor([[False, False, False, ..., True, True, True],\n",
" [False, False, False, ..., False, False, False],\n",
" [False, False, False, ..., False, False, False],\n",
" ...,\n",
" [False, False, False, ..., False, False, False],\n",
" [False, False, False, ..., False, False, False],\n",
" [False, False, False, ..., False, False, False]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,\n",
" 327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,\n",
" 319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,\n",
" 319, 321, 320, 321, 316, 311])}})\n",
"\n"
]
}
],
"source": [
"try:\n",
" batches = next(data_reader)\n",
"except StopIteration:\n",
" batches = None\n",
"\n",
"if batches is not None:\n",
" for batch_nr, batch in enumerate(batches):\n",
" print(f\"===batch_nr==={batch_nr}===\")\n",
" print(batch)\n",
" print(\"\")\n",
"else:\n",
" print(\"No more batches\")\n",
" data_reader.reset()"
]
}
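,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional follow-up, we can decode one of the sequences back to text to eyeball what the model will actually see. This is a sketch that assumes `batches` from the previous cell is not `None` and that the tokenizer exposes `create_decoder()` (the standard fairseq2 `TextTokenizer` interface); padded positions will show up in the decoded string."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A sketch: decode the first sequence of the first batch back to text.\n",
"# Assumes `batches` is not None and that tokenizer.create_decoder() is available.\n",
"if batches:\n",
"    decoder = tokenizer.create_decoder()\n",
"    first_batch = batches[0]\n",
"    # move to CPU before decoding; note that padded positions are included\n",
"    print(decoder(first_batch.seqs[0].cpu()))"
]
}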
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}