Parquet Dataloader¶
Overview¶
Prerequisite: make sure that you have installed fairseq2 with pip install fairseq2[arrow]. This installs the additional packages required by the parquet dataloader (e.g. pyarrow, retrying, polars, xxhash).
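To quickly confirm that the optional arrow dependencies are importable before running the notebook, here is a minimal sketch using only the standard library (the package names are the ones listed above):
import importlib.util

for name in ("pyarrow", "polars", "xxhash", "retrying"):
    found = importlib.util.find_spec(name) is not None
    print(f"{name}: {'ok' if found else 'missing'}")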
[1]:
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.compute as pc
import tempfile
from pathlib import Path
Fragments Streaming¶
[2]:
from fairseq2.data.parquet.fragment_streaming import (
    FragmentStreamingConfig, ParquetFragmentStreamer
)
[3]:
table = pa.Table.from_pydict(
    {
        "col1": [1, 2, 3, 4, 5],
        "col2": ["a", "b", "c", "d", "e"],
        "col3": [1.1, 2.2, 3.3, 4.4, 5.5],
    }
)

# Create a temporary directory and file
with tempfile.TemporaryDirectory() as temp_dir:
    file_path = Path(temp_dir) / "test.parquet"

    # Write the parquet file
    pq.write_table(table, file_path)

    # Simple configuration
    config = FragmentStreamingConfig(
        parquet_path=Path(temp_dir),
        nb_epochs=2,
        split_to_row_groups=False,  # False makes each fragment correspond to a whole file
        fragment_shuffle_window=100,  # Shuffle within a window of 100 fragments
        seed=42,  # For reproducibility
    )

    # Create the streamer
    streamer = ParquetFragmentStreamer(config=config)
    print(streamer.dataset.read().to_pandas())

    # Build a pipeline for a specific rank/world_size (for distributed training)
    fragment_pipeline = streamer.build_pipeline(rank=0, world_size=1).and_return()

    result = list(fragment_pipeline)

    # The result is a list of fragments pointing to the physical data locations
    # from which the data will be loaded
    print(result)
    print(result[0].to_table())

    # The two epochs should produce the same result
    assert result[0].to_table().equals(result[1].to_table())
col1 col2 col3
0 1 a 1.1
1 2 b 2.2
2 3 c 3.3
3 4 d 4.4
4 5 e 5.5
[<pyarrow.dataset.ParquetFileFragment path=/tmp/tmp7sxrjcuf/test.parquet>, <pyarrow.dataset.ParquetFileFragment path=/tmp/tmp7sxrjcuf/test.parquet>]
pyarrow.Table
col1: int64
col2: string
col3: double
----
col1: [[1,2,3,4,5]]
col2: [["a","b","c","d","e"]]
col3: [[1.1,2.2,3.3,4.4,5.5]]
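The same build_pipeline(rank, world_size) call is what shards the fragment stream across distributed workers. A minimal sketch (meant to be run inside the TemporaryDirectory block above, reusing the streamer defined there) comparing what two ranks would receive:
world_size = 2
for rank in range(world_size):
    shard = list(streamer.build_pipeline(rank=rank, world_size=world_size).and_return())
    print(f"rank {rank} receives {len(shard)} fragments")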
Fragments Loading¶
[4]:
import numpy as np
# create a sample parquet dataset
row_groups_size_distribution = [2, 1, 3]
row_group_size = 10
total_size = sum(row_groups_size_distribution) * row_group_size
data = {
    "cat": [
        f"cat_{j}"
        for j, size in enumerate(row_groups_size_distribution)
        for _ in range(size * 10)
    ],
    "id": [f"id_{i}" for i in range(total_size)],
    "seq": [np.arange(i % 10 + 2) for i in range(total_size)],
}
table = pa.Table.from_pydict(data)

tmp_dir = Path(tempfile.gettempdir()) / "parquet_dataset_test"
tmp_parquet_ds_path = tmp_dir / "test2"

pq.write_to_dataset(
    table,
    tmp_parquet_ds_path,
    partition_cols=["cat"],
    **{"row_group_size": row_group_size},
)
[5]:
table
[5]:
pyarrow.Table
cat: string
id: string
seq: list<item: int64>
child 0, item: int64
----
cat: [["cat_0","cat_0","cat_0","cat_0","cat_0",...,"cat_2","cat_2","cat_2","cat_2","cat_2"]]
id: [["id_0","id_1","id_2","id_3","id_4",...,"id_55","id_56","id_57","id_58","id_59"]]
seq: [[[0,1],[0,1,2],...,[0,1,2,3,4,5,6,7,8,9],[0,1,2,3,4,5,6,7,8,9,10]]]
[6]:
!ls -l {tmp_parquet_ds_path}/*
'/tmp/parquet_dataset_test/test2/cat=cat_0':
total 32
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 17:39 51cd53ffbcd34e5f97af1c8b0c256d54-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 16:08 6bffd35df261477b81016fa4c1c0769d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 14:27 7597b212cc43460ab26d29f0f7fd6ad8-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 16:10 a7ea9c7884904b8da48b4677260db50d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 13:20 bff977f0052a45aebe8237477e9ec495-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 15:01 c28be1610eaf4cb79ec9c0e30e6123fe-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 15:55 d7d20b8a34424d50949bdac0fb3a0455-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1438 Mar 25 15:52 f61ea7b57b4b4d88ad2047e91b9d93ee-0.parquet
'/tmp/parquet_dataset_test/test2/cat=cat_1':
total 32
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 17:39 51cd53ffbcd34e5f97af1c8b0c256d54-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 16:08 6bffd35df261477b81016fa4c1c0769d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 14:27 7597b212cc43460ab26d29f0f7fd6ad8-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 16:10 a7ea9c7884904b8da48b4677260db50d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 13:20 bff977f0052a45aebe8237477e9ec495-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 15:01 c28be1610eaf4cb79ec9c0e30e6123fe-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 15:55 d7d20b8a34424d50949bdac0fb3a0455-0.parquet
-rw-rw-r-- 1 yaoj yaoj 941 Mar 25 15:52 f61ea7b57b4b4d88ad2047e91b9d93ee-0.parquet
'/tmp/parquet_dataset_test/test2/cat=cat_2':
total 32
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 17:39 51cd53ffbcd34e5f97af1c8b0c256d54-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 16:08 6bffd35df261477b81016fa4c1c0769d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 14:27 7597b212cc43460ab26d29f0f7fd6ad8-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 16:10 a7ea9c7884904b8da48b4677260db50d-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 13:20 bff977f0052a45aebe8237477e9ec495-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 15:01 c28be1610eaf4cb79ec9c0e30e6123fe-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 15:55 d7d20b8a34424d50949bdac0fb3a0455-0.parquet
-rw-rw-r-- 1 yaoj yaoj 1947 Mar 25 15:52 f61ea7b57b4b4d88ad2047e91b9d93ee-0.parquet
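Each partition file above was written with row_group_size=10, and these row groups are what split_to_row_groups=True (used in the next cell) turns into individual fragments. You can inspect the layout with plain pyarrow.parquet, for example on one cat=cat_0 file:
one_file = next((tmp_parquet_ds_path / "cat=cat_0").glob("*.parquet"))
meta = pq.ParquetFile(one_file).metadata
print(meta.num_row_groups, "row groups,", meta.num_rows, "rows in", one_file.name)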
[7]:
from dataclasses import dataclass, field
from typing import List

from fairseq2.data.parquet.fragment_loading import (
    FragmentLoadingConfig, NamedColumns, ParquetFragmentLoader
)


@dataclass
class MyColumns(NamedColumns):
    # Format: new_name: original_column_name
    category: str = "cat"
    uid: str = "id"
    extra_columns: List[str] = field(default_factory=lambda: ["seq"])


fragment_config = FragmentStreamingConfig(
    parquet_path=tmp_parquet_ds_path,
    nb_epochs=2,
    seed=42,
    split_to_row_groups=True,
    fragment_shuffle_window=10,
    files_circular_shift=False,
    partition_filters=[
        'pc.field("cat") == "cat_0"',  # keep only the cat_0 partition; comment this out to stream all partitions
    ],
)
streamer = ParquetFragmentStreamer(config=fragment_config)

# Create the loading config
loading_config = FragmentLoadingConfig(
    columns=MyColumns(),
    cache=False,
    rename_columns=False,
    add_fragment_traces=True,  # Add tracking columns
    drop_null=True,  # Drop rows with null values
    nb_prefetch=2,  # Prefetch 2 fragments
    num_parallel_fragments=4,  # Process 4 fragments in parallel
    filters="pc.greater_equal(pc.list_value_length(pc.field('seq')), 4)",  # comment this line out to disable row-level filtering
)

# Build the loading pipeline
loader = ParquetFragmentLoader(config=loading_config)

fragment_pipeline = streamer.build_pipeline(0, 1)
loading_pipeline = loader.apply(fragment_pipeline)

result = list(iter(loading_pipeline.and_return()))

print(result[0])
pyarrow.Table
id: string
seq: list<element: int64>
child 0, element: int64
__batch_index: int32
__fragment_index: int32
__filename: string
cat: dictionary<values=string, indices=int32, ordered=0>
__row_groups_ids: int32
__index_in_fragement: int32
----
id: [["id_2","id_3","id_4","id_5","id_6","id_7","id_8","id_9"]]
seq: [[[0,1,2,3],[0,1,2,3,4],...,[0,1,2,3,4,5,6,7,8,9],[0,1,2,3,4,5,6,7,8,9,10]]]
__batch_index: [[0,0,0,0,0,0,0,0]]
__fragment_index: [[0,0,0,0,0,0,0,0]]
__filename: [["/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet","/tmp/parquet_dataset_test/test2/cat=cat_0/7597b212cc43460ab26d29f0f7fd6ad8-0.parquet"]]
cat: [ -- dictionary:
["cat_0"] -- indices:
[0,0,0,0,0,0,0,0]]
__row_groups_ids: [[0,0,0,0,0,0,0,0]]
__index_in_fragement: [[2,3,4,5,6,7,8,9]]
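As a quick sanity check that the filters expression was applied, the seq lengths of the loaded table can be recomputed with plain pyarrow.compute (using result[0] from the cell above):
lengths = pc.list_value_length(result[0]["seq"])
print(pc.min(lengths).as_py())  # every remaining row has at least 4 elements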
Table Bucketing¶
[15]:
from fairseq2.data.parquet.table_bucketing import (
    TableBucketingConfig, TableBucketer
)

streamer = ParquetFragmentStreamer(config=fragment_config)
loader = ParquetFragmentLoader(config=loading_config)

fragment_pipeline = streamer.build_pipeline(0, 1)
loading_pipeline = loader.apply(fragment_pipeline)

# Create bucketing config
bucketing_config = TableBucketingConfig(
    target_table_size=1000,  # Aim for tables with 1000 rows
    min_fragment_number=2,   # Combine at least 2 fragments
    max_fragment_number=10,  # Combine at most 10 fragments
    shuffle=True,            # Shuffle rows in memory
    batch_size=32,           # Return batches of 32 rows
)

# Apply bucketing
bucketer = TableBucketer(bucketing_config)
final_pipeline = bucketer.apply(loading_pipeline).and_return()

# Iterate through batches
for batch in final_pipeline:
    # batch is a PyArrow Table
    print(batch.select(["id"]))
    print(len(batch))
    print("=" * 10)
pyarrow.Table
id: string
----
id: [["id_4","id_6","id_9","id_15","id_13",...,"id_6","id_19","id_17","id_14","id_8"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_8","id_18","id_5","id_17","id_7",...,"id_13","id_2","id_3","id_9","id_4"]]
16
==========
pyarrow.Table
id: string
----
id: [["id_16","id_12","id_18","id_16","id_6",...,"id_8","id_4","id_2","id_16","id_9"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_9","id_4","id_5","id_2","id_8",...,"id_18","id_12","id_8","id_9","id_7"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_15","id_5","id_13","id_4","id_18",...,"id_14","id_19","id_6","id_13","id_5"]]
16
==========
pyarrow.Table
id: string
----
id: [["id_3","id_15","id_15","id_7","id_17",...,"id_17","id_16","id_14","id_19","id_18"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_12","id_15","id_2","id_18","id_17",...,"id_3","id_12","id_15","id_7","id_9"]]
16
==========
pyarrow.Table
id: string
----
id: [["id_18","id_19","id_4","id_8","id_19",...,"id_14","id_5","id_8","id_13","id_13"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_4","id_18","id_7","id_9","id_18",...,"id_4","id_14","id_7","id_13","id_3"]]
32
==========
pyarrow.Table
id: string
----
id: [["id_5","id_13","id_2","id_6","id_16",...,"id_18","id_9","id_7","id_15","id_3"]]
16
==========
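Each batch is a regular pyarrow Table, so per-batch statistics are cheap to compute with pyarrow.compute. A sketch, assuming the bucketing pipeline is rebuilt as in the cell above (the previous one has already been consumed):
streamer = ParquetFragmentStreamer(config=fragment_config)
loader = ParquetFragmentLoader(config=loading_config)
final_pipeline = TableBucketer(bucketing_config).apply(
    loader.apply(streamer.build_pipeline(0, 1))
).and_return()

for batch in final_pipeline:
    n_tokens = pc.sum(pc.list_value_length(batch["seq"])).as_py()
    print(f"{len(batch)} rows, {n_tokens} seq tokens")
    break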
Complete Pipeline¶
[17]:
from fairseq2.data.parquet import (
    BasicDataLoadingConfig,
    build_basic_parquet_data_pipeline,
    FragmentStreamingConfig,
    FragmentLoadingConfig,
    TableBucketingConfig,
)

# Configure the entire pipeline
config = BasicDataLoadingConfig(
    fragment_stream_config=FragmentStreamingConfig(
        parquet_path=tmp_parquet_ds_path,
        nb_epochs=None,  # Infinite iterations
        fragment_shuffle_window=100,
    ),
    fragment_load_config=FragmentLoadingConfig(
        nb_prefetch=2,
        num_parallel_fragments=3,
    ),
    table_bucketing_config=TableBucketingConfig(
        target_table_size=1000,
        min_fragment_number=2,
        max_fragment_number=10,
        shuffle=True,
        batch_size=32,
    ),
)

# Create the pipeline
pipeline = build_basic_parquet_data_pipeline(config).and_return()

# Use the pipeline
for batch in pipeline:
    print(batch.select(["id"]))
    print(len(batch))
    print("=" * 10)
    break
pyarrow.Table
id: string
----
id: [["id_42","id_41","id_21","id_36"]]
4
==========
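Since nb_epochs=None streams indefinitely, cap the iteration explicitly when experimenting; a minimal sketch using itertools.islice on a freshly built pipeline:
from itertools import islice

pipeline = build_basic_parquet_data_pipeline(config).and_return()
for batch in islice(pipeline, 3):
    print(len(batch))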
Working with PyArrow Tables¶
PyArrow tables can be converted to various formats:
- pl.from_arrow(pa_table, rechunk=False) converts to a polars DataFrame with almost zero memory copy;
- pa.Table.to_pylist(), or pl.from_arrow(...).to_dicts() (usually much faster), converts to a list of dictionaries;
- parquet/utils.py:pyarrow_table_to_torch_dict converts a pyarrow table into a dictionary of CPU torch tensors (best effort).
[10]:
# Convert to pandas
df = batch.to_pandas()
# Convert to dictionary
batch_dict = batch.to_pydict()
# Convert to torch tensors
from fairseq2.data.parquet.utils import pyarrow_table_to_torch_dict
tensor_dict = pyarrow_table_to_torch_dict(batch)
# Using Polars (fast with zero-copy)
import polars as pl
polars_df = pl.from_arrow(batch, rechunk=False)
# Convert to list of dictionaries (rows)
rows = batch.to_pylist()
# Or using polars (usually much faster)
rows = pl.from_arrow(batch, rechunk=False).to_dicts()
/fsx-checkpoints/yaoj/envs/fs2_nightly_pt25_cu121/conda/lib/python3.10/site-packages/pandas/core/algorithms.py:1743: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
return lib.map_infer(values, mapper, convert=convert)
Transformation¶
[11]:
from fairseq2.data.parquet.arrow_transform import filter_strings_by_length


# Create a custom transformation
def my_transform(table: pa.Table) -> pa.Table:
    # Apply filtering by text length
    table = filter_strings_by_length(table, "id", min_len=3, max_len=3)
    return table


streamer = ParquetFragmentStreamer(config=fragment_config)
loader = ParquetFragmentLoader(config=loading_config)

fragment_pipeline = streamer.build_pipeline(0, 1)
loading_pipeline = loader.apply(fragment_pipeline)

# Apply the transformation
final_pipeline = loading_pipeline.map(my_transform).and_return()

# Use the pipeline
for batch in final_pipeline:
    print(batch.to_pydict()["id"])
    break
['id_4', 'id_0', 'id_33', 'id_12', 'id_29', 'id_49', 'id_4', 'id_30', 'id_2', 'id_0', 'id_17', 'id_3', 'id_8', 'id_34', 'id_5', 'id_0', 'id_18', 'id_32', 'id_4', 'id_23', 'id_3', 'id_31', 'id_13', 'id_47', 'id_1', 'id_35', 'id_33', 'id_16', 'id_38', 'id_19', 'id_20', 'id_7']
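Because a transform is just a function from pa.Table to pa.Table, plain pyarrow.compute composes the same way. A hedged sketch that extends my_transform with a derived sequence-length column (the column name seq_len is illustrative):
def my_transform_with_len(table: pa.Table) -> pa.Table:
    table = filter_strings_by_length(table, "id", min_len=3, max_len=3)
    # Append the length of each "seq" list as a new column
    return table.append_column("seq_len", pc.list_value_length(table["seq"]))

# final_pipeline = loading_pipeline.map(my_transform_with_len).and_return()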