{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ✎ Datasets\n", "\n", "## Overview\n", "\n", "This tutorial demonstrates how to interact with pre-defined datasets in fairseq2.\n", "We use the `gsm8k_sft` (generic instruction finetuning) dataset as an example." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import all necessary modules" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "from fairseq2 import setup_fairseq2\n", "from fairseq2.context import get_runtime_context\n", "from fairseq2.datasets import Batching, LengthBatching, StaticBatching\n", "from fairseq2.recipes.common import (\n", " load_dataset,\n", " load_text_tokenizer,\n", " setup_gangs,\n", ")\n", "from fairseq2.recipes.config import GangSection\n", "from fairseq2.recipes.lm import InstructionFinetuneDatasetSection\n", "from fairseq2.datasets.instruction import (\n", " InstructionDataset,\n", " InstructionReadOptions,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initialization\n", "\n", "We first need to initialize fairseq2 -- `setup_fairseq2()`.\n", "This will load the configuration and register the assets, which allows us to interact with pre-defined datasets and models." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Setup fairseq2\n", "setup_fairseq2()\n", "\n", "context = get_runtime_context()\n", "\n", "# Load the configuration\n", "dataset_config = InstructionFinetuneDatasetSection()\n", "\n", "dataset_config.name = \"gsm8k_sft\"\n", "dataset_config.path = Path(\"/path/to/gsm8k_data/sft\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare the assets\n", "\n", "We will load both the dataset and the model card. The `retrieve_asset_card` function is used to load the asset card from the asset store." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# prepare the seed\n", "seed = 42\n", "\n", "# prepare the gang\n", "gangs = setup_gangs(context, GangSection(tensor_parallel_size=5))\n", "dataset = load_dataset(InstructionDataset, context, dataset_config, gangs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# prepare the batching strategy\n", "batching: Batching\n", "\n", "if dataset_config.batch_size is not None:\n", " batching = StaticBatching(dataset_config.batch_size)\n", "else:\n", " batching = LengthBatching(dataset_config.max_num_tokens)\n", "\n", "# prepare the read options\n", "read_options = InstructionReadOptions(\n", " batching=batching,\n", " example_shuffle_window=dataset_config.example_shuffle_window,\n", " batch_shuffle_window=dataset_config.batch_shuffle_window,\n", " num_prefetch=dataset_config.num_prefetch,\n", " source_encode_mode=dataset_config.source_encode_mode,\n", " target_encode_mode=dataset_config.target_encode_mode,\n", " seed=seed,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# load the tokenizer\n", "tokenizer = load_text_tokenizer(context, \"llama3_1_8b\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Data Reader\n", "\n", "To create a data reader, we need to prepare the gang and the batching options as well.\n", "If you dig into the `create_reader` method, you will see that it implements the data pipeline that is covered in `notebooks/data/datapipeline.ipynb`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_reader = dataset.create_reader(\n", " dataset_config.train_split,\n", " tokenizer,\n", " gangs.dp,\n", " dataset_config.min_seq_len,\n", " dataset_config.max_seq_len,\n", " read_options,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Iterate over the batches\n", "\n", "Now that we have the data reader, we can iterate over the batches." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "===batch_nr===0===\n", "SequenceBatch(seqs=tensor([[128000, 128006, 882, ..., 220, 10132, 128009],\n", " [128000, 128006, 882, ..., 0, 0, 0],\n", " [128000, 128006, 882, ..., 0, 0, 0],\n", " ...,\n", " [128000, 128006, 882, ..., 0, 0, 0],\n", " [128000, 128006, 882, ..., 0, 0, 0],\n", " [128000, 128006, 882, ..., 0, 0, 0]]), padding_mask=, target_mask=tensor([[False, False, False, ..., True, True, True],\n", " [False, False, False, ..., False, False, False],\n", " [False, False, False, ..., False, False, False],\n", " ...,\n", " [False, False, False, ..., False, False, False],\n", " [False, False, False, ..., False, False, False],\n", " [False, False, False, ..., False, False, False]]), example={'id': [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None], 'indices': {'is_ragged': True, 'seqs': tensor([[128000, 128006, 882, ..., 220, 10132, 128009],\n", " [128000, 128006, 882, ..., 0, 0, 0],\n", " [128000, 128006, 882, ..., 0, 0, 0],\n", " ...,\n", " [128000, 128006, 882, ..., 0, 0, 0],\n", " [128000, 128006, 882, ..., 0, 0, 0],\n", " [128000, 128006, 882, ..., 0, 0, 0]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,\n", " 327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,\n", " 319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,\n", " 319, 321, 320, 321, 316, 311])}, 'target_mask': {'is_ragged': True, 'seqs': tensor([[False, False, False, ..., True, True, True],\n", " [False, False, False, ..., False, False, False],\n", " [False, False, False, ..., False, False, False],\n", " ...,\n", " [False, False, False, ..., False, False, False],\n", " [False, False, False, ..., False, False, False],\n", " [False, False, False, ..., False, False, False]]), 'seq_lens': tensor([338, 332, 329, 329, 330, 333, 334, 333, 331, 329, 328, 334, 324, 322,\n", " 327, 323, 324, 323, 322, 325, 322, 326, 326, 322, 327, 325, 325, 322,\n", " 319, 321, 318, 320, 321, 317, 321, 316, 316, 319, 318, 317, 320, 316,\n", " 319, 321, 320, 321, 316, 311])}})\n", "\n" ] } ], "source": [ "try:\n", " batches = next(data_reader)\n", "except StopIteration:\n", " batches = None\n", "\n", "if batches is not None:\n", " for batch_nr, batch in enumerate(batches):\n", " print(f\"===batch_nr==={batch_nr}===\")\n", " print(batch)\n", " print(\"\")\n", "else:\n", " print(\"No more batches\")\n", " data_reader.reset()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", 
"version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 2 }