{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ✎ Parquet Dataloader\n", "\n", "## Overview\n", "\n", "> Prerequisite: make sure that you have installed fairseq2 with `pip install fairseq2[arrow]`. This will install additional packages required for parquet dataloader (_e.g._ pyarrow, retrying, polars, xxhash)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pyarrow as pa\n", "import pyarrow.parquet as pq\n", "import pyarrow.compute as pc\n", "import tempfile\n", "from pathlib import Path" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fragments Streaming" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from fairseq2.data.parquet.fragment_streaming import (\n", " FragmentStreamingConfig, ParquetFragmentStreamer\n", ")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " col1 col2 col3\n", "0 1 a 1.1\n", "1 2 b 2.2\n", "2 3 c 3.3\n", "3 4 d 4.4\n", "4 5 e 5.5\n", "[, ]\n", "pyarrow.Table\n", "col1: int64\n", "col2: string\n", "col3: double\n", "----\n", "col1: [[1,2,3,4,5]]\n", "col2: [[\"a\",\"b\",\"c\",\"d\",\"e\"]]\n", "col3: [[1.1,2.2,3.3,4.4,5.5]]\n" ] } ], "source": [ "table = pa.Table.from_pydict(\n", " {\n", " \"col1\": [1, 2, 3, 4, 5],\n", " \"col2\": [\"a\", \"b\", \"c\", \"d\", \"e\"],\n", " \"col3\": [1.1, 2.2, 3.3, 4.4, 5.5],\n", " }\n", ")\n", "\n", "# Create a temporary directory and file\n", "with tempfile.TemporaryDirectory() as temp_dir:\n", " file_path = Path(temp_dir) / \"test.parquet\"\n", "\n", " # Write the parquet file\n", " pq.write_table(table, file_path)\n", "\n", " # Simple configuration\n", " config = FragmentStreamingConfig(\n", " parquet_path=Path(temp_dir),\n", " nb_epochs=2,\n", " split_to_row_groups=False, # Set to False makes a fragment correspond to a file\n", " fragment_shuffle_window=100, # Shuffle within a window of 100 fragments\n", " seed=42, # For reproducibility\n", " )\n", "\n", " # Create the streamer\n", " streamer = ParquetFragmentStreamer(config=config)\n", " \n", " print(streamer.dataset.read().to_pandas())\n", "\n", " # Build a pipeline for a specific rank/world_size (for distributed training)\n", " fragment_pipeline = streamer.build_pipeline(rank=0, world_size=1).and_return()\n", " result = list(fragment_pipeline)\n", "\n", " # the result is a list of fragments that points to the physical data location from which the data will be loaded\n", " print(result)\n", " print(result[0].to_table())\n", " # the 2 epochs should produce the same result\n", " assert result[0].to_table().equals(result[1].to_table())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fragments Loading" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# create a sample parquet dataset\n", "row_groups_size_distribution = [2, 1, 3]\n", "row_group_size = 10\n", "\n", "total_size = sum(row_groups_size_distribution) * row_group_size\n", "\n", "data = {\n", " \"cat\": [\n", " f\"cat_{j}\"\n", " for j, size in enumerate(row_groups_size_distribution)\n", " for _ in range(size * 10)\n", " ],\n", " \"id\": [f\"id_{i}\" for i in range(total_size)],\n", " \"seq\": [np.arange(i % 10 + 2) for i in range(total_size)],\n", "}\n", "table = pa.Table.from_pydict(data)\n", "\n", "tmp_dir = Path(tempfile.gettempdir()) / \"parquet_dataset_test\"\n", "tmp_parquet_ds_path = tmp_dir / 
\"test2\"\n", "\n", "pq.write_to_dataset(\n", " table,\n", " tmp_parquet_ds_path,\n", " partition_cols=[\"cat\"],\n", " **{\"row_group_size\": row_group_size},\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "pyarrow.Table\n", "cat: string\n", "id: string\n", "seq: list\n", " child 0, item: int64\n", "----\n", "cat: [[\"cat_0\",\"cat_0\",\"cat_0\",\"cat_0\",\"cat_0\",...,\"cat_2\",\"cat_2\",\"cat_2\",\"cat_2\",\"cat_2\"]]\n", "id: [[\"id_0\",\"id_1\",\"id_2\",\"id_3\",\"id_4\",...,\"id_55\",\"id_56\",\"id_57\",\"id_58\",\"id_59\"]]\n", "seq: [[[0,1],[0,1,2],...,[0,1,2,3,4,5,6,7,8,9],[0,1,2,3,4,5,6,7,8,9,10]]]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "table" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "'/tmp/parquet_dataset_test/test2/cat=cat_0':\n", "total 32\n", "-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 17:39 51cd53ffbcd34e5f97af1c8b0c256d54-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 16:08 6bffd35df261477b81016fa4c1c0769d-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 14:27 7597b212cc43460ab26d29f0f7fd6ad8-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 16:10 a7ea9c7884904b8da48b4677260db50d-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 13:20 bff977f0052a45aebe8237477e9ec495-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 15:01 c28be1610eaf4cb79ec9c0e30e6123fe-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 15:55 d7d20b8a34424d50949bdac0fb3a0455-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 15:52 f61ea7b57b4b4d88ad2047e91b9d93ee-0.parquet\n", "\n", "'/tmp/parquet_dataset_test/test2/cat=cat_1':\n", "total 32\n", "-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 17:39 51cd53ffbcd34e5f97af1c8b0c256d54-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 16:08 6bffd35df261477b81016fa4c1c0769d-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 14:27 7597b212cc43460ab26d29f0f7fd6ad8-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 16:10 a7ea9c7884904b8da48b4677260db50d-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 13:20 bff977f0052a45aebe8237477e9ec495-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 15:01 c28be1610eaf4cb79ec9c0e30e6123fe-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 15:55 d7d20b8a34424d50949bdac0fb3a0455-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 15:52 f61ea7b57b4b4d88ad2047e91b9d93ee-0.parquet\n", "\n", "'/tmp/parquet_dataset_test/test2/cat=cat_2':\n", "total 32\n", "-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 17:39 51cd53ffbcd34e5f97af1c8b0c256d54-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 16:08 6bffd35df261477b81016fa4c1c0769d-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 14:27 7597b212cc43460ab26d29f0f7fd6ad8-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 16:10 a7ea9c7884904b8da48b4677260db50d-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 13:20 bff977f0052a45aebe8237477e9ec495-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 15:01 c28be1610eaf4cb79ec9c0e30e6123fe-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 15:55 d7d20b8a34424d50949bdac0fb3a0455-0.parquet\n", "-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 15:52 f61ea7b57b4b4d88ad2047e91b9d93ee-0.parquet\n" ] } ], "source": [ "!ls -l {tmp_parquet_ds_path}/*" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pyarrow.Table\n", "id: string\n", "seq: list\n", " child 0, element: int64\n", "__batch_index: int32\n", 
"__fragment_index: int32\n", "__filename: string\n", "cat: dictionary\n", "__row_groups_ids: int32\n", "__index_in_fragement: int32\n", "----\n", "id: [[\"id_2\",\"id_3\",\"id_4\",\"id_5\",\"id_6\",\"id_7\",\"id_8\",\"id_9\"]]\n", "seq: [[[0,1,2,3],[0,1,2,3,4],...,[0,1,2,3,4,5,6,7,8,9],[0,1,2,3,4,5,6,7,8,9,10]]]\n", "__batch_index: [[0,0,0,0,0,0,0,0]]\n", "__fragment_index: [[0,0,0,0,0,0,0,0]]\n", "__filename: [[\"/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet\",\"/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet\",\"/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet\",\"/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet\",\"/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet\",\"/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet\",\"/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet\",\"/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet\"]]\n", "cat: [ -- dictionary:\n", "[\"cat_0\"] -- indices:\n", "[0,0,0,0,0,0,0,0]]\n", "__row_groups_ids: [[0,0,0,0,0,0,0,0]]\n", "__index_in_fragement: [[2,3,4,5,6,7,8,9]]\n" ] } ], "source": [ "from dataclasses import dataclass, field\n", "from typing import List\n", "from fairseq2.data.parquet.fragment_loading import (\n", " FragmentLoadingConfig, NamedColumns, ParquetFragmentLoader\n", ")\n", "\n", "@dataclass\n", "class MyColumns(NamedColumns):\n", " # Format: new_name: original_column_name\n", " category: str = \"cat\"\n", " uid: str = \"id\"\n", " extra_columns: List[str] = field(default_factory=lambda: [\"seq\"])\n", "\n", "fragment_config = FragmentStreamingConfig(\n", " parquet_path=tmp_parquet_ds_path,\n", " nb_epochs=2,\n", " seed=42,\n", " split_to_row_groups=True,\n", " fragment_shuffle_window=10,\n", " files_circular_shift=False,\n", " partition_filters=[\n", " 'pc.field(\"cat\") == \"cat_0\"', # uncomment this line to see the effect of partition filters\n", " ]\n", ")\n", "\n", "streamer = ParquetFragmentStreamer(config=fragment_config)\n", "\n", "# Create the loading config\n", "loading_config = FragmentLoadingConfig(\n", " columns=MyColumns(),\n", " cache=False,\n", " rename_columns=False,\n", " add_fragment_traces=True, # Add tracking columns\n", " drop_null=True, # Drop rows with null values\n", " nb_prefetch=2, # Prefetch 2 fragments\n", " num_parallel_fragments=4, # Process 4 fragments in parallel\n", " filters=\"pc.greater_equal(pc.list_value_length(pc.field('seq')), 4)\", # you can comment this line out to see the effect of filters\n", ")\n", "\n", "# Build the loading pipeline\n", "loader = ParquetFragmentLoader(config=loading_config)\n", "\n", "fragment_pipeline = streamer.build_pipeline(0, 1)\n", "loading_pipeline = loader.apply(fragment_pipeline)\n", "\n", "result = list(iter(loading_pipeline.and_return()))\n", "\n", "print(result[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Table Bucketing" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pyarrow.Table\n", "id: string\n", "----\n", "id: [[\"id_4\",\"id_6\",\"id_9\",\"id_15\",\"id_13\",...,\"id_6\",\"id_19\",\"id_17\",\"id_14\",\"id_8\"]]\n", "32\n", "==========\n", "pyarrow.Table\n", "id: string\n", "----\n", "id: 
[[\"id_8\",\"id_18\",\"id_5\",\"id_17\",\"id_7\",...,\"id_13\",\"id_2\",\"id_3\",\"id_9\",\"id_4\"]]\n", "16\n", "==========\n", "pyarrow.Table\n", "id: string\n", "----\n", "id: [[\"id_16\",\"id_12\",\"id_18\",\"id_16\",\"id_6\",...,\"id_8\",\"id_4\",\"id_2\",\"id_16\",\"id_9\"]]\n", "32\n", "==========\n", "pyarrow.Table\n", "id: string\n", "----\n", "id: [[\"id_9\",\"id_4\",\"id_5\",\"id_2\",\"id_8\",...,\"id_18\",\"id_12\",\"id_8\",\"id_9\",\"id_7\"]]\n", "32\n", "==========\n", "pyarrow.Table\n", "id: string\n", "----\n", "id: [[\"id_15\",\"id_5\",\"id_13\",\"id_4\",\"id_18\",...,\"id_14\",\"id_19\",\"id_6\",\"id_13\",\"id_5\"]]\n", "16\n", "==========\n", "pyarrow.Table\n", "id: string\n", "----\n", "id: [[\"id_3\",\"id_15\",\"id_15\",\"id_7\",\"id_17\",...,\"id_17\",\"id_16\",\"id_14\",\"id_19\",\"id_18\"]]\n", "32\n", "==========\n", "pyarrow.Table\n", "id: string\n", "----\n", "id: [[\"id_12\",\"id_15\",\"id_2\",\"id_18\",\"id_17\",...,\"id_3\",\"id_12\",\"id_15\",\"id_7\",\"id_9\"]]\n", "16\n", "==========\n", "pyarrow.Table\n", "id: string\n", "----\n", "id: [[\"id_18\",\"id_19\",\"id_4\",\"id_8\",\"id_19\",...,\"id_14\",\"id_5\",\"id_8\",\"id_13\",\"id_13\"]]\n", "32\n", "==========\n", "pyarrow.Table\n", "id: string\n", "----\n", "id: [[\"id_4\",\"id_18\",\"id_7\",\"id_9\",\"id_18\",...,\"id_4\",\"id_14\",\"id_7\",\"id_13\",\"id_3\"]]\n", "32\n", "==========\n", "pyarrow.Table\n", "id: string\n", "----\n", "id: [[\"id_5\",\"id_13\",\"id_2\",\"id_6\",\"id_16\",...,\"id_18\",\"id_9\",\"id_7\",\"id_15\",\"id_3\"]]\n", "16\n", "==========\n" ] } ], "source": [ "from fairseq2.data.parquet.table_bucketing import (\n", " TableBucketingConfig, TableBucketer\n", ")\n", "\n", "streamer = ParquetFragmentStreamer(config=fragment_config)\n", "loader = ParquetFragmentLoader(config=loading_config)\n", "\n", "fragment_pipeline = streamer.build_pipeline(0, 1)\n", "loading_pipeline = loader.apply(fragment_pipeline)\n", "\n", "# Create bucketing config\n", "bucketing_config = TableBucketingConfig(\n", " target_table_size=1000, # Aim for tables with 1000 rows\n", " min_fragment_number=2, # Combine at least 2 fragments\n", " max_fragment_number=10, # Combine at most 10 fragments\n", " shuffle=True, # Shuffle rows in memory\n", " batch_size=32 # Return batches of 32 rows\n", ")\n", "\n", "# Apply bucketing\n", "bucketer = TableBucketer(bucketing_config)\n", "final_pipeline = bucketer.apply(loading_pipeline).and_return()\n", "\n", "# Iterate through batches\n", "for batch in final_pipeline:\n", " # batch is a PyArrow Table\n", " print(batch.select([\"id\"]))\n", " print(len(batch))\n", " print(\"=\"*10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Complete Pipeline" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pyarrow.Table\n", "id: string\n", "----\n", "id: [[\"id_42\",\"id_41\",\"id_21\",\"id_36\"]]\n", "4\n", "==========\n" ] } ], "source": [ "from fairseq2.data.parquet import (\n", " BasicDataLoadingConfig,\n", " build_basic_parquet_data_pipeline,\n", " FragmentStreamingConfig,\n", " FragmentLoadingConfig,\n", " TableBucketingConfig\n", ")\n", "\n", "# Configure the entire pipeline\n", "config = BasicDataLoadingConfig(\n", " fragment_stream_config=FragmentStreamingConfig(\n", " parquet_path=tmp_parquet_ds_path,\n", " nb_epochs=None, # Infinite iterations\n", " fragment_shuffle_window=100\n", " ),\n", " fragment_load_config=FragmentLoadingConfig(\n", " nb_prefetch=2,\n", " 
num_parallel_fragments=3\n", "    ),\n", "    table_bucketing_config=TableBucketingConfig(\n", "        target_table_size=1000,\n", "        min_fragment_number=2,\n", "        max_fragment_number=10,\n", "        shuffle=True,\n", "        batch_size=32\n", "    ),\n", ")\n", "\n", "# Create the pipeline\n", "pipeline = build_basic_parquet_data_pipeline(config).and_return()\n", "\n", "# Use the pipeline\n", "for batch in pipeline:\n", "    print(batch.select([\"id\"]))\n", "    print(len(batch))\n", "    print(\"=\"*10)\n", "    break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Working w/ PyArrow Tables\n", "\n", "PyArrow tables can be converted to various other formats:\n", "- Using [polars](https://docs.pola.rs/), one can use `pl.from_arrow(pa_table, rechunk=False)` to convert into a polars dataframe (almost zero-copy);\n", "- `pa.Table.to_pylist()` or `pl.from_arrow(...).to_dicts()` (usually much faster) to convert into a list of dictionaries;\n", "- `fairseq2.data.parquet.utils.pyarrow_table_to_torch_dict` to convert a pyarrow table into a dictionary of CPU torch tensors (best effort)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/fsx-checkpoints/yaoj/envs/fs2_nightly_pt25_cu121/conda/lib/python3.10/site-packages/pandas/core/algorithms.py:1743: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)\n", "  return lib.map_infer(values, mapper, convert=convert)\n" ] } ], "source": [ "# Convert to pandas\n", "df = batch.to_pandas()\n", "\n", "# Convert to dictionary\n", "batch_dict = batch.to_pydict()\n", "\n", "# Convert to torch tensors\n", "from fairseq2.data.parquet.utils import pyarrow_table_to_torch_dict\n", "tensor_dict = pyarrow_table_to_torch_dict(batch)\n", "\n", "# Using Polars (fast, almost zero-copy)\n", "import polars as pl\n", "polars_df = pl.from_arrow(batch, rechunk=False)\n", "\n", "# Convert to list of dictionaries (rows)\n", "rows = batch.to_pylist()\n", "# Or using polars (usually much faster)\n", "rows = pl.from_arrow(batch, rechunk=False).to_dicts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transformation" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['id_4', 'id_0', 'id_33', 'id_12', 'id_29', 'id_49', 'id_4', 'id_30', 'id_2', 'id_0', 'id_17', 'id_3', 'id_8', 'id_34', 'id_5', 'id_0', 'id_18', 'id_32', 'id_4', 'id_23', 'id_3', 'id_31', 'id_13', 'id_47', 'id_1', 'id_35', 'id_33', 'id_16', 'id_38', 'id_19', 'id_20', 'id_7']\n" ] } ], "source": [ "from fairseq2.data.parquet.arrow_transform import filter_strings_by_length\n", "\n", "# Create a custom transformation\n", "def my_transform(table: pa.Table) -> pa.Table:\n", "    # Apply filtering by text length\n", "    table = filter_strings_by_length(table, \"id\", min_len=3, max_len=3)\n", "    return table\n", "\n", "streamer = ParquetFragmentStreamer(config=fragment_config)\n", "loader = ParquetFragmentLoader(config=loading_config)\n", "\n", "fragment_pipeline = streamer.build_pipeline(0, 1)\n", "loading_pipeline = loader.apply(fragment_pipeline)\n", "\n", "# Apply the transformation\n", 
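"# map() runs my_transform on each pa.Table produced by the loading pipeline (one table per loaded fragment)\n", 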
"final_pipeline = loading_pipeline.map(my_transform).and_return()\n", "\n", "# Use the pipeline\n", "for batch in final_pipeline:\n", "    print(batch.to_pydict()[\"id\"])\n", "    break" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }