✎ Parquet Dataloader

Overview

Prerequisite: make sure that fairseq2 is installed with pip install fairseq2[arrow]. This installs the additional packages required by the parquet dataloader (e.g. pyarrow, retrying, polars, xxhash).

[1]:
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.compute as pc
import tempfile
from pathlib import Path

Fragments Streaming

[2]:
from fairseq2.data.parquet.fragment_streaming import (
    FragmentStreamingConfig, ParquetFragmentStreamer
)
[3]:
# Build a small in-memory table that is written below as a single-file parquet dataset.
table = pa.Table.from_pydict(
    {
        "col1": [1, 2, 3, 4, 5],
        "col2": ["a", "b", "c", "d", "e"],
        "col3": [1.1, 2.2, 3.3, 4.4, 5.5],
    }
)

# Create a temporary directory and file. The directory is deleted when the `with`
# block exits, so everything that reads the data must run inside it.
with tempfile.TemporaryDirectory() as temp_dir:
    file_path = Path(temp_dir) / "test.parquet"

    # Write the parquet file
    pq.write_table(table, file_path)

    # Simple configuration: stream the whole directory (not the single file) for 2 epochs
    config = FragmentStreamingConfig(
        parquet_path=Path(temp_dir),
        nb_epochs=2,
        split_to_row_groups=False,  # False: each fragment corresponds to a whole file
        fragment_shuffle_window=100,  # Shuffle within a window of 100 fragments
        seed=42,  # For reproducibility
    )

    # Create the streamer
    streamer = ParquetFragmentStreamer(config=config)

    # Print the full content of the dataset the streamer points to
    print(streamer.dataset.read().to_pandas())

    # Build a pipeline for a specific rank/world_size (for distributed training);
    # here a single process: rank 0 out of a world size of 1
    fragment_pipeline = streamer.build_pipeline(rank=0, world_size=1).and_return()
    result = list(fragment_pipeline)

    # The result is a list of fragments that point to the physical data location
    # from which the data will be loaded (here: the single file, once per epoch)
    print(result)
    print(result[0].to_table())
    # With a single file there is nothing to shuffle, so both epochs yield the same data
    assert result[0].to_table().equals(result[1].to_table())
   col1 col2  col3
0     1    a   1.1
1     2    b   2.2
2     3    c   3.3
3     4    d   4.4
4     5    e   5.5
[<pyarrow.dataset.ParquetFileFragment path=/tmp/tmp7sxrjcuf/test.parquet>, <pyarrow.dataset.ParquetFileFragment path=/tmp/tmp7sxrjcuf/test.parquet>]
pyarrow.Table
col1: int64
col2: string
col3: double
----
col1: [[1,2,3,4,5]]
col2: [["a","b","c","d","e"]]
col3: [[1.1,2.2,3.3,4.4,5.5]]

Fragments Loading

[4]:
import shutil

import numpy as np

# Create a sample parquet dataset: categories "cat_0", "cat_1" and "cat_2" hold
# 2, 1 and 3 row groups' worth of rows respectively.
row_groups_size_distribution = [2, 1, 3]
row_group_size = 10

total_size = sum(row_groups_size_distribution) * row_group_size

data = {
    # One category label per row. The count is derived from row_group_size (not a
    # hard-coded 10) so every column keeps length total_size if row_group_size changes.
    "cat": [
        f"cat_{j}"
        for j, size in enumerate(row_groups_size_distribution)
        for _ in range(size * row_group_size)
    ],
    "id": [f"id_{i}" for i in range(total_size)],
    # Variable-length integer lists of length 2..11
    "seq": [np.arange(i % 10 + 2) for i in range(total_size)],
}
table = pa.Table.from_pydict(data)

tmp_dir = Path(tempfile.gettempdir()) / "parquet_dataset_test"
tmp_parquet_ds_path = tmp_dir / "test2"

# The dataset lives at a fixed path and write_to_dataset adds new, uniquely named
# files on every call. Clear it first so re-running this cell does not accumulate
# duplicate copies of the data (which would silently multiply every row downstream).
shutil.rmtree(tmp_parquet_ds_path, ignore_errors=True)

# Partition the files by "cat" (one "cat=<value>" directory per category) and write
# row groups of row_group_size rows.
pq.write_to_dataset(
    table,
    tmp_parquet_ds_path,
    partition_cols=["cat"],
    row_group_size=row_group_size,
)
[5]:
# Display the in-memory table that was just written to disk
table
[5]:
pyarrow.Table
cat: string
id: string
seq: list<item: int64>
  child 0, item: int64
----
cat: [["cat_0","cat_0","cat_0","cat_0","cat_0",...,"cat_2","cat_2","cat_2","cat_2","cat_2"]]
id: [["id_0","id_1","id_2","id_3","id_4",...,"id_55","id_56","id_57","id_58","id_59"]]
seq: [[[0,1],[0,1,2],...,[0,1,2,3,4,5,6,7,8,9],[0,1,2,3,4,5,6,7,8,9,10]]]
[6]:
# IPython shell escape: list the files written into each "cat=..." partition directory
!ls -l {tmp_parquet_ds_path}/*
'/tmp/parquet_dataset_test/test2/cat=cat_0':
total 32
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 17:39 51cd53ffbcd34e5f97af1c8b0c256d54-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 16:08 6bffd35df261477b81016fa4c1c0769d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 14:27 7597b212cc43460ab26d29f0f7fd6ad8-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 16:10 a7ea9c7884904b8da48b4677260db50d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 13:20 bff977f0052a45aebe8237477e9ec495-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 15:01 c28be1610eaf4cb79ec9c0e30e6123fe-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 15:55 d7d20b8a34424d50949bdac0fb3a0455-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 15:52 f61ea7b57b4b4d88ad2047e91b9d93ee-0.parquet

'/tmp/parquet_dataset_test/test2/cat=cat_1':
total 32
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 17:39 51cd53ffbcd34e5f97af1c8b0c256d54-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 16:08 6bffd35df261477b81016fa4c1c0769d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 14:27 7597b212cc43460ab26d29f0f7fd6ad8-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 16:10 a7ea9c7884904b8da48b4677260db50d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 13:20 bff977f0052a45aebe8237477e9ec495-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 15:01 c28be1610eaf4cb79ec9c0e30e6123fe-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 15:55 d7d20b8a34424d50949bdac0fb3a0455-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 15:52 f61ea7b57b4b4d88ad2047e91b9d93ee-0.parquet

'/tmp/parquet_dataset_test/test2/cat=cat_2':
total 32
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 17:39 51cd53ffbcd34e5f97af1c8b0c256d54-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 16:08 6bffd35df261477b81016fa4c1c0769d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 14:27 7597b212cc43460ab26d29f0f7fd6ad8-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 16:10 a7ea9c7884904b8da48b4677260db50d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 13:20 bff977f0052a45aebe8237477e9ec495-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 15:01 c28be1610eaf4cb79ec9c0e30e6123fe-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 15:55 d7d20b8a34424d50949bdac0fb3a0455-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 15:52 f61ea7b57b4b4d88ad2047e91b9d93ee-0.parquet
[7]:
from dataclasses import dataclass, field
from typing import List
from fairseq2.data.parquet.fragment_loading import (
    FragmentLoadingConfig, NamedColumns, ParquetFragmentLoader
)

@dataclass
class MyColumns(NamedColumns):
    """Columns to load from the sample parquet dataset.

    Each field name is the new name for a column, and its default value is the
    name of the original column in the dataset.
    """

    # Format: new_name: original_column_name
    category: str = "cat"
    uid: str = "id"
    # Additional original columns to load (here the integer-list column "seq")
    extra_columns: List[str] = field(default_factory=lambda: ["seq"])

# Stream the partitioned dataset for 2 epochs, with files split into row-group fragments.
fragment_config = FragmentStreamingConfig(
    parquet_path=tmp_parquet_ds_path,
    nb_epochs=2,
    seed=42,
    split_to_row_groups=True,
    fragment_shuffle_window=10,
    files_circular_shift=False,
    partition_filters=[
        'pc.field("cat") == "cat_0"',  # comment out this line to see the effect of partition filters
    ]
)

streamer = ParquetFragmentStreamer(config=fragment_config)

# Create the loading config
loading_config = FragmentLoadingConfig(
    columns=MyColumns(),
    cache=False,
    rename_columns=False,  # keep the original column names (id, seq, cat)
    add_fragment_traces=True,  # Add tracking columns (__batch_index, __fragment_index, __filename, ...)
    drop_null=True,  # Drop rows with null values
    nb_prefetch=2,  # Prefetch 2 fragments
    num_parallel_fragments=4,  # Process 4 fragments in parallel
    # Keep only rows whose "seq" list has at least 4 elements;
    # comment this line out to see the effect of filters
    filters="pc.greater_equal(pc.list_value_length(pc.field('seq')), 4)",
)

# Build the loading pipeline
loader = ParquetFragmentLoader(config=loading_config)

# rank 0 out of a world size of 1 (single process)
fragment_pipeline = streamer.build_pipeline(0, 1)
loading_pipeline = loader.apply(fragment_pipeline)

# Materialize every loaded table (one pyarrow.Table per fragment)
result = list(loading_pipeline.and_return())

print(result[0])
pyarrow.Table
id: string
seq: list<element: int64>
  child 0, element: int64
__batch_index: int32
__fragment_index: int32
__filename: string
cat: dictionary<values=string, indices=int32, ordered=0>
__row_groups_ids: int32
__index_in_fragement: int32
----
id: [["id_2","id_3","id_4","id_5","id_6","id_7","id_8","id_9"]]
seq: [[[0,1,2,3],[0,1,2,3,4],...,[0,1,2,3,4,5,6,7,8,9],[0,1,2,3,4,5,6,7,8,9,10]]]
__batch_index: [[0,0,0,0,0,0,0,0]]
__fragment_index: [[0,0,0,0,0,0,0,0]]
__filename: [["/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet"]]
cat: [  -- dictionary:
["cat_0"]  -- indices:
[0,0,0,0,0,0,0,0]]
__row_groups_ids: [[0,0,0,0,0,0,0,0]]
__index_in_fragement: [[2,3,4,5,6,7,8,9]]

Table Bucketing

[15]:
from fairseq2.data.parquet.table_bucketing import (
    TableBucketingConfig, TableBucketer
)

# Build fresh streaming and loading stages from the configs defined above
# (the earlier loading pipeline was already fully iterated by `list(...)`).
streamer = ParquetFragmentStreamer(config=fragment_config)
loader = ParquetFragmentLoader(config=loading_config)

fragment_pipeline = streamer.build_pipeline(0, 1)
loading_pipeline = loader.apply(fragment_pipeline)

# Create bucketing config
bucketing_config = TableBucketingConfig(
    target_table_size=1000,  # Aim for tables with 1000 rows
    min_fragment_number=2,   # Combine at least 2 fragments
    max_fragment_number=10,  # Combine at most 10 fragments
    shuffle=True,            # Shuffle rows in memory
    batch_size=32            # Return batches of up to 32 rows (some batches can be smaller)
)

# Apply bucketing on top of the loading pipeline
bucketer = TableBucketer(bucketing_config)
final_pipeline = bucketer.apply(loading_pipeline).and_return()

# Iterate through batches
for batch in final_pipeline:
    # batch is a PyArrow Table
    print(batch.select(["id"]))
    print(len(batch))
    print("="*10)
pyarrow.Table
id: string
----
id: [["id_4","id_6","id_9","id_15","id_13",...,"id_6","id_19","id_17","id_14","id_8"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_8","id_18","id_5","id_17","id_7",...,"id_13","id_2","id_3","id_9","id_4"]]
16
==========
pyarrow.Table
id: string
----
id: [["id_16","id_12","id_18","id_16","id_6",...,"id_8","id_4","id_2","id_16","id_9"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_9","id_4","id_5","id_2","id_8",...,"id_18","id_12","id_8","id_9","id_7"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_15","id_5","id_13","id_4","id_18",...,"id_14","id_19","id_6","id_13","id_5"]]
16
==========
pyarrow.Table
id: string
----
id: [["id_3","id_15","id_15","id_7","id_17",...,"id_17","id_16","id_14","id_19","id_18"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_12","id_15","id_2","id_18","id_17",...,"id_3","id_12","id_15","id_7","id_9"]]
16
==========
pyarrow.Table
id: string
----
id: [["id_18","id_19","id_4","id_8","id_19",...,"id_14","id_5","id_8","id_13","id_13"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_4","id_18","id_7","id_9","id_18",...,"id_4","id_14","id_7","id_13","id_3"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_5","id_13","id_2","id_6","id_16",...,"id_18","id_9","id_7","id_15","id_3"]]
16
==========

Complete Pipeline

[17]:
from fairseq2.data.parquet import (
    BasicDataLoadingConfig,
    build_basic_parquet_data_pipeline,
    FragmentStreamingConfig,
    FragmentLoadingConfig,
    TableBucketingConfig
)

# Configure the entire pipeline (fragment streaming -> fragment loading -> table bucketing)
config = BasicDataLoadingConfig(
    fragment_stream_config=FragmentStreamingConfig(
        parquet_path=tmp_parquet_ds_path,
        nb_epochs=None,  # Infinite iterations
        fragment_shuffle_window=100
    ),
    fragment_load_config=FragmentLoadingConfig(
        nb_prefetch=2,
        num_parallel_fragments=3
    ),
    table_bucketing_config=TableBucketingConfig(
        target_table_size=1000,
        min_fragment_number=2,
        max_fragment_number=10,
        shuffle=True,
        batch_size=32
    ),
)

# Create the pipeline
pipeline = build_basic_parquet_data_pipeline(config).and_return()

# Use the pipeline. With nb_epochs=None the stream does not end on its own, so stop
# after the first batch; `batch` stays defined and is reused by the conversion examples.
for batch in pipeline:
    print(batch.select(["id"]))
    print(len(batch))
    print("="*10)
    break
pyarrow.Table
id: string
----
id: [["id_42","id_41","id_21","id_36"]]
4
==========

Working with PyArrow Tables

PyArrow tables can be converted to various formats.

  • With polars, use pl.from_arrow(pa_table, rechunk=False) to convert a table into a polars DataFrame (almost zero-copy);

  • Use pa.Table.to_pylist() or pl.from_arrow(...).to_dicts() (usually much faster) to convert a table into a list of dictionaries;

  • Use parquet/utils.py:pyarrow_table_to_torch_dict to convert a pyarrow table into a dictionary of CPU torch tensors (best effort).

[10]:
# Converters used below
import polars as pl

from fairseq2.data.parquet.utils import pyarrow_table_to_torch_dict

# `batch` is the pyarrow.Table taken from the pipeline in the previous section.

# -> pandas DataFrame
df = batch.to_pandas()

# -> dict mapping each column name to the list of its values
batch_dict = batch.to_pydict()

# -> dict of CPU torch tensors (best-effort conversion)
tensor_dict = pyarrow_table_to_torch_dict(batch)

# -> polars DataFrame; rechunk=False keeps the conversion (almost) zero-copy
polars_df = pl.from_arrow(batch, rechunk=False)

# -> list of per-row dicts, either directly with pyarrow ...
rows = batch.to_pylist()
# ... or through polars, which is usually much faster
rows = polars_df.to_dicts()
/fsx-checkpoints/yaoj/envs/fs2_nightly_pt25_cu121/conda/lib/python3.10/site-packages/pandas/core/algorithms.py:1743: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return lib.map_infer(values, mapper, convert=convert)

Transformation

[11]:
from fairseq2.data.parquet.arrow_transform import filter_strings_by_length

# A custom transformation: any callable taking and returning a pyarrow.Table can be
# mapped over the loading pipeline.
def my_transform(table: pa.Table):
    """Filter the rows of ``table`` by the text length of their "id" column.

    NOTE(review): the ids generated above look like "id_<n>" (4+ characters), so
    min_len=max_len=3 may drop every row; confirm the intended bounds.
    """
    return filter_strings_by_length(table, "id", min_len=3, max_len=3)

# Build fresh streaming and loading stages from the configs defined earlier
streamer = ParquetFragmentStreamer(config=fragment_config)
loader = ParquetFragmentLoader(config=loading_config)

fragment_pipeline = streamer.build_pipeline(0, 1)
loading_pipeline = loader.apply(fragment_pipeline)

# Apply the transformation to every loaded table
final_pipeline = loading_pipeline.map(my_transform).and_return()

# Use the transformed pipeline. Iterating `pipeline` here (the "Complete Pipeline"
# one) would bypass my_transform entirely and read unfiltered data.
for batch in final_pipeline:
    print(batch.to_pydict()["id"])
    break
['id_4', 'id_0', 'id_33', 'id_12', 'id_29', 'id_49', 'id_4', 'id_30', 'id_2', 'id_0', 'id_17', 'id_3', 'id_8', 'id_34', 'id_5', 'id_0', 'id_18', 'id_32', 'id_4', 'id_23', 'id_3', 'id_31', 'id_13', 'id_47', 'id_1', 'id_35', 'id_33', 'id_16', 'id_38', 'id_19', 'id_20', 'id_7']