Data Pipeline

Overview

This notebook shows by example how to use DataPipeline to build data pipelines: how to combine several pipelines, and how to shuffle, bucket, map, and checkpoint their elements.

[2]:
from fairseq2.data import DataPipeline, read_sequence

Combine multiple pipelines

Round Robin

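DataPipeline.round_robin combines multiple pipelines into a new pipeline that yields one element from each input pipeline in turn. As the examples below show, a finite input pipeline that runs out is restarted from its beginning, and the combined pipeline ends once every finite input has been exhausted at least once. Pseudo-infinite inputs such as DataPipeline.constant do not keep the combined pipeline running, while a truly infinite input (for example, one built with repeat()) means the combined pipeline never ends.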
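To make the stopping rule concrete, here is a minimal pure-Python model of round-robin over finite, non-empty sequences. It is a hypothetical sketch inferred from the outputs of the examples below, not fairseq2's implementation, and `round_robin_sketch` is an illustrative name: each round reads one element from every source, an exhausted source is restarted, and iteration ends at the first round in which every source has been exhausted at least once.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def round_robin_sketch(sources: Sequence[Sequence[T]]) -> Iterator[T]:
    """Hypothetical model of round-robin over finite, non-empty sequences."""
    iterators = [iter(src) for src in sources]
    exhausted = [False] * len(sources)

    while True:
        round_items: List[T] = []
        for k, src in enumerate(sources):
            try:
                item = next(iterators[k])
            except StopIteration:
                # Remember that this source ran out, then restart it.
                exhausted[k] = True
                iterators[k] = iter(src)
                item = next(iterators[k])
            round_items.append(item)

        # Stop once every source has reached its end at least once;
        # the items read in this final round are discarded.
        if all(exhausted):
            return
        yield from round_items


print(list(round_robin_sketch([[1, 2, 3, 4], [5, 6, 7, 8], [0, 2, 4, 6]])))
# -> [1, 5, 0, 2, 6, 2, 3, 7, 4, 4, 8, 6]
```

In this model, shorter sources are restarted while longer ones are still running, so `round_robin_sketch([[1, 2], [3]])` yields `[1, 3, 2, 3]`; the notebook does not show fairseq2's output for inputs of unequal length.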

[8]:
# finite pipeline
pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
pipeline2 = read_sequence([5, 6, 7, 8]).and_return()
pipeline3 = read_sequence([0, 2, 4, 6]).and_return()

pipeline = DataPipeline.round_robin(
    [pipeline1, pipeline2, pipeline3]
).and_return()

for i in pipeline:
    print(i)

pipeline.reset()

for i in pipeline:
    print(i)
1
5
0
2
6
2
3
7
4
4
8
6
1
5
0
2
6
2
3
7
4
4
8
6
[10]:
# pseudo-infinite pipeline
pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
pipeline2 = DataPipeline.constant(0).and_return()
pipeline3 = read_sequence([0, 2, 4, 6]).and_return()

pipeline = DataPipeline.round_robin(
    [pipeline1, pipeline2, pipeline3]
).and_return()

for _ in range(2):
    assert list(pipeline) == [1, 0, 0, 2, 0, 2, 3, 0, 4, 4, 0, 6]

    pipeline.reset()
[11]:
# infinite pipeline
pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
pipeline2 = read_sequence([0]).repeat().and_return()
pipeline3 = read_sequence([0, 2, 4, 6]).and_return()

pipeline = DataPipeline.round_robin(
    [pipeline1, pipeline2, pipeline3]
).and_return()

for _ in range(2):
    it = iter(pipeline)

    assert [next(it) for _ in range(15)] == [1, 0, 0, 2, 0, 2, 3, 0, 4, 4, 0, 6, 1, 0, 0]

    pipeline.reset()

Zip

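DataPipeline.zip combines multiple pipelines into a new pipeline that, at each step, reads one element from every input pipeline and yields them together as a list, in the order in which the pipelines were passed. With flatten=True and dict elements, the dicts read at each step are instead merged into a single dict, as the last example in this section shows.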
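For equal-length inputs, the behavior in the examples below matches Python's built-in `zip`, with each tuple turned into a list; the `flatten=True` case with dicts amounts to merging the dicts of each step. The following is a hypothetical sketch (the function names are illustrative, not fairseq2 APIs) that assumes equal-length inputs and disjoint dict keys, as in the examples; note that the built-in `zip` silently stops at the shortest input, and how fairseq2 handles unequal lengths is not shown here.

```python
from typing import Any, Dict, Iterator, List, Sequence


def zip_sketch(sources: Sequence[Sequence[Any]]) -> Iterator[List[Any]]:
    """Hypothetical model of zip: one list per step, one item per source."""
    for items in zip(*sources):
        yield list(items)


def zip_flatten_dicts_sketch(
    sources: Sequence[Sequence[Dict[str, Any]]],
) -> Iterator[Dict[str, Any]]:
    """Hypothetical model of zip(..., flatten=True) for dict elements."""
    for items in zip(*sources):
        merged: Dict[str, Any] = {}
        for item in items:
            # Merge the keys of every source's dict into one dict.
            merged.update(item)
        yield merged


print(list(zip_sketch([[1, 2], [5, 6]])))  # [[1, 5], [2, 6]]
```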

[13]:
# finite pipeline
pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
pipeline2 = read_sequence([5, 6, 7, 8]).and_return()
pipeline3 = read_sequence([0, 2, 4, 6]).and_return()

pipeline = DataPipeline.zip([pipeline1, pipeline2, pipeline3]).and_return()

for _ in range(2):
    assert list(pipeline) == [[1, 5, 0], [2, 6, 2], [3, 7, 4], [4, 8, 6]]

    pipeline.reset()
[14]:
# pseudo-infinite pipeline
pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
pipeline2 = DataPipeline.constant(0).and_return()
pipeline3 = read_sequence([5, 6, 7, 8]).and_return()

pipeline = DataPipeline.zip([pipeline1, pipeline2, pipeline3]).and_return()

for _ in range(2):
    assert list(pipeline) == [[1, 0, 5], [2, 0, 6], [3, 0, 7], [4, 0, 8]]

    pipeline.reset()
[16]:
# infinite pipeline
pipeline1 = read_sequence([0]).repeat().and_return()
pipeline2 = read_sequence([1]).repeat().and_return()
pipeline3 = read_sequence([2]).repeat().and_return()

pipeline = DataPipeline.zip([pipeline1, pipeline2, pipeline3]).and_return()

for _ in range(2):
    it = iter(pipeline)

    assert [next(it) for i in range(2)] == [[0, 1, 2], [0, 1, 2]]

    pipeline.reset()
[17]:
# flatten and dict input
pipeline1 = read_sequence([{"foo1": 1}, {"foo1": 2}, {"foo1": 3}]).and_return()
pipeline2 = read_sequence([{"foo2": 4, "foo3": 5}, {"foo2": 6, "foo3": 7}, {"foo2": 8, "foo3": 9}]).and_return()  # fmt: skip
pipeline3 = read_sequence([{"foo4": 2}, {"foo4": 3}, {"foo4": 4}]).and_return()

pipeline = DataPipeline.zip(
    [pipeline1, pipeline2, pipeline3], flatten=True
).and_return()

for _ in range(2):
    assert list(pipeline) == [
        {"foo1": 1, "foo2": 4, "foo3": 5, "foo4": 2},
        {"foo1": 2, "foo2": 6, "foo3": 7, "foo4": 3},
        {"foo1": 3, "foo2": 8, "foo3": 9, "foo4": 4},
    ]

    pipeline.reset()

Sample

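DataPipeline.sample draws each element from one of several pipelines, chosen at random according to weights. The weights argument is a list of floats giving the relative sampling weight of each pipeline; the weights do not need to sum to 1 (the first example uses [1.2, 0.8]). The seed argument sets the random seed for the sampling process, and reset(reset_rng=True) restores that seed so that the same sequence is produced again. By default, a pipeline that runs out is restarted, and sampling continues until every finite pipeline has been exhausted at least once. With allow_repeats=False, an exhausted pipeline is dropped instead, so each element is yielded exactly once.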
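The sampling rules described above can be modeled in pure Python. This is a hypothetical sketch over finite, non-empty sources, not fairseq2's implementation: it uses Python's `random` module, so it will not reproduce the exact sequences asserted in the cells below, and `sample_sketch` is an illustrative name. `random.choices` treats the weights as relative, which is why they need not sum to 1.

```python
import random
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def sample_sketch(
    sources: Sequence[Sequence[T]],
    weights: Sequence[float],
    seed: int,
    allow_repeats: bool = True,
) -> Iterator[T]:
    """Hypothetical model of weighted sampling over finite, non-empty sources."""
    rng = random.Random(seed)
    iterators = [iter(src) for src in sources]
    exhausted = [False] * len(sources)
    active = list(range(len(sources)))

    while active:
        # Pick a source index with probability proportional to its weight.
        k = rng.choices(active, weights=[weights[i] for i in active])[0]
        try:
            item = next(iterators[k])
        except StopIteration:
            exhausted[k] = True
            if all(exhausted):
                return
            if not allow_repeats:
                # Drop the exhausted source from further sampling.
                active.remove(k)
                continue
            # Restart the exhausted source and read its first element again.
            iterators[k] = iter(sources[k])
            item = next(iterators[k])
        yield item
```

With `allow_repeats=False`, the output is an interleaving of the sources in which every element appears exactly once and each source keeps its own order, mirroring the last example in this section.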

[18]:
# finite pipeline
pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
pipeline2 = read_sequence([5, 6, 7]).and_return()

pipeline = DataPipeline.sample(
    [pipeline1, pipeline2], weights=[1.2, 0.8], seed=1234
).and_return()

for _ in range(2):
    assert list(pipeline) == [1, 2, 3, 4, 1, 5, 2, 3, 6, 4, 7]

    pipeline.reset(reset_rng=True)
[19]:
# pseudo-infinite pipeline
pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
pipeline2 = DataPipeline.count(5).and_return()

pipeline = DataPipeline.sample(
    [pipeline1, pipeline2], weights=[0.4, 0.6], seed=1234
).and_return()

for _ in range(2):
    assert list(pipeline) == [1, 5, 2, 3, 4]

    pipeline.reset(reset_rng=True)
[20]:
# infinite pipeline
pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
pipeline2 = read_sequence([5, 6, 7, 8]).repeat().and_return()

pipeline = DataPipeline.sample(
    [pipeline1, pipeline2], weights=[0.4, 0.6], seed=1234
).and_return()

it = iter(pipeline)

assert [next(it) for i in range(10)] == [1, 5, 2, 3, 4, 6, 1, 7, 8, 2]
[21]:
# allow_repeats=False
pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
pipeline2 = read_sequence([5, 6]).and_return()
pipeline3 = read_sequence([7, 8, 9, 10, 11, 12]).and_return()

pipeline = DataPipeline.sample(
    [pipeline1, pipeline2, pipeline3],
    weights=[0.3, 0.6, 0.1],
    allow_repeats=False,
    seed=1234,
).and_return()

for _ in range(2):
    assert list(pipeline) == [1, 5, 2, 6, 3, 4, 7, 8, 9, 10, 11, 12]

    pipeline.reset(reset_rng=True)

Concatenate

DataPipeline.concat concatenates multiple pipelines into a single pipeline: it yields every element of the first pipeline, then every element of the second, and so on.
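For in-memory sequences, the same ordering is what the standard library's `itertools.chain` produces. This is only an analogy to illustrate the semantics, not how fairseq2 implements concat.

```python
import itertools

# Analogy: yield every element of the first source, then of the second.
first = [1, 2, 3, 4]
second = [5, 6, 7, 8]
concatenated = list(itertools.chain(first, second))
print(concatenated)  # [1, 2, 3, 4, 5, 6, 7, 8]
```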

[35]:
pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
pipeline2 = read_sequence([5, 6, 7, 8]).and_return()

pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()

for _ in range(2):
    assert list(pipeline) == [1, 2, 3, 4, 5, 6, 7, 8]

    pipeline.reset()

Shuffle

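The shuffle method shuffles the elements in the pipeline. Its first argument sets the shuffle buffer size: if the buffer size is larger than the number of elements in the pipeline, or is set to 0, the whole pipeline is shuffled at once, and a buffer size of 1 leaves the order unchanged. The seed argument sets the random seed for the shuffling process. Calling reset(reset_rng=True) restores that seed, so the next pass repeats the same order, whereas reset(reset_rng=False) keeps the current random state, so the next pass is generally shuffled differently.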
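The buffer-size-4 example below is consistent with a block-wise strategy: its first four outputs are a permutation of 1-4 and the next four a permutation of 5-8. The following is a minimal sketch of that strategy, assuming it; fairseq2's actual algorithm may differ, `buffered_shuffle_sketch` is a hypothetical name, and Python's RNG will not reproduce the exact orders asserted below.

```python
import itertools
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle_sketch(items: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
    """Hypothetical block-wise shuffle: fill a buffer, shuffle it, emit it, repeat."""
    rng = random.Random(seed)
    it = iter(items)
    while True:
        if buffer_size == 0:
            buffer: List[T] = list(it)  # 0 means "buffer everything"
        else:
            buffer = list(itertools.islice(it, buffer_size))
        if not buffer:
            return
        rng.shuffle(buffer)
        yield from buffer


# A buffer of one element cannot reorder anything.
print(list(buffered_shuffle_sketch(range(1, 10), 1, seed=1234)))
# -> [1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Larger buffers mix elements over a wider window at the cost of holding more of them in memory; a buffer covering the whole input gives a full shuffle.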

[27]:
seq = list(range(1, 10))

# Shuffle buffer 100 > 9 elements -> full shuffle
pipeline = read_sequence(seq).shuffle(100, seed=1234).and_return()

for _ in range(2):
    assert list(pipeline) == [8, 9, 3, 7, 5, 4, 2, 6, 1]

    pipeline.reset(reset_rng=True)

# Consume one full pass so that the RNG state advances past the first shuffle
_ = list(pipeline)

# reset the pipeline without resetting the seed
pipeline.reset(reset_rng=False)

# We haven't reset the seed. The list should be different this time.
assert list(pipeline) != [8, 9, 3, 7, 5, 4, 2, 6, 1]
[28]:
# Shuffle the whole list by setting shuffle buffer to 0
pipeline = read_sequence(seq).shuffle(0, seed=1234).and_return()

for _ in range(2):
    assert list(pipeline) == [8, 9, 3, 7, 5, 4, 2, 6, 1]

    pipeline.reset(reset_rng=True)
[29]:
# Shuffle 4 elements per buffer
pipeline = read_sequence(seq).shuffle(4, seed=1234).and_return()

for _ in range(2):
    assert list(pipeline) == [2, 1, 3, 4, 5, 7, 8, 6, 9]

    pipeline.reset(reset_rng=True)

# Shuffle 1 element per buffer -> no shuffle
pipeline = read_sequence(seq).shuffle(1, seed=1234).and_return()

for _ in range(2):
    assert list(pipeline) == [1, 2, 3, 4, 5, 6, 7, 8, 9]

    pipeline.reset(reset_rng=True)

Bucket

The bucket method groups consecutive elements of the pipeline into lists ("buckets"), which is useful when you want to process data in batches. Its bucket_size argument sets the number of elements in each bucket.

The dynamic_bucket method instead sizes buckets by cost: a cost function assigns a cost to each element, and a bucket is emitted once the accumulated cost of its elements reaches a threshold. It takes the following arguments:

  • The threshold argument sets the accumulated cost at which a bucket is emitted.

  • The cost_fn argument maps each element to its cost; the costs of the elements in a bucket are summed.

  • The min_num_examples argument sets the minimum number of elements in each bucket; a bucket with fewer elements is not emitted even if it has reached the threshold.

  • The max_num_examples argument sets the maximum number of elements in each bucket; a bucket is emitted as soon as it holds this many elements, even below the threshold.

  • The drop_remainder argument sets whether to drop the leftover elements at the end that never formed a complete bucket.
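The rules in the list above can be modeled in a few lines of pure Python. This is a hypothetical sketch (`dynamic_bucket_sketch` is not a fairseq2 API) that reproduces the outputs of the dynamic bucketing examples below; it does not model edge cases such as min_num_examples being larger than max_num_examples.

```python
from typing import Callable, Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def dynamic_bucket_sketch(
    items: Iterable[T],
    threshold: float,
    cost_fn: Callable[[T], float],
    min_num_examples: Optional[int] = None,
    max_num_examples: Optional[int] = None,
    drop_remainder: bool = False,
) -> Iterator[List[T]]:
    """Hypothetical model of cost-based bucketing."""
    bucket: List[T] = []
    cost = 0.0
    for item in items:
        bucket.append(item)
        cost += cost_fn(item)
        # Close the bucket when its total cost reaches the threshold,
        # provided it holds at least min_num_examples elements ...
        reached_threshold = cost >= threshold and (
            min_num_examples is None or len(bucket) >= min_num_examples
        )
        # ... or as soon as it holds max_num_examples elements.
        is_full = max_num_examples is not None and len(bucket) >= max_num_examples
        if reached_threshold or is_full:
            yield bucket
            bucket, cost = [], 0.0
    # Leftover elements never satisfied the conditions above.
    if bucket and not drop_remainder:
        yield bucket


# Fixed-size bucketing is a special case: every element costs 1.
print(list(dynamic_bucket_sketch(range(10), 4, lambda _: 1)))
# -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```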

[38]:
# simple bucketing
seq = list(range(100))

bucket_size = 4

pipeline = read_sequence(seq).bucket(bucket_size).and_return()

for _ in range(2):
    it = iter(pipeline)

    for i in range(25):
        d = next(it)

        offset = i * bucket_size

        assert d == [offset + j for j in range(bucket_size)]

    try:
        next(it)
    except StopIteration:
        print("StopIteration")

    pipeline.reset()
StopIteration
StopIteration
[44]:
# dynamic bucketing (square cost function)
seq = list(range(1, 7))

threshold = 14
cost_fn = lambda x: x**2  # each element costs its square; a bucket's cost is the sum

pipeline = read_sequence(seq).dynamic_bucket(threshold, cost_fn).and_return()

for _ in range(2):
    it = iter(pipeline)

    assert next(it) == [1, 2, 3]
    assert next(it) == [4]
    assert next(it) == [5]
    assert next(it) == [6]

    try:
        next(it)
    except StopIteration:
        print("StopIteration")

    pipeline.reset()
StopIteration
StopIteration
[46]:
# dynamic bucketing (length cost function)
seq = [[1, 2], [3, 4, 5], [6], [7], [8, 9, 10], [11, 12, 13, 14, 15, 16]]

threshold = 5
cost_fn = lambda x: len(x)

pipeline = read_sequence(seq).dynamic_bucket(threshold, cost_fn).and_return()

for _ in range(2):
    it = iter(pipeline)

    assert next(it) == [[1, 2], [3, 4, 5]]
    assert next(it) == [[6], [7], [8, 9, 10]]
    assert next(it) == [[11, 12, 13, 14, 15, 16]]

    try:
        next(it)
    except StopIteration:
        print("StopIteration")

    pipeline.reset()
StopIteration
StopIteration
[48]:
# dynamic bucketing (more constraints)
seq = [0, 0, 0, 0, 1, 2, 3, 4, 5]

threshold = 3
cost_fn = lambda x: x

pipeline = (
    read_sequence(seq)
    .dynamic_bucket(threshold, cost_fn, min_num_examples=2, max_num_examples=2, drop_remainder=True)
    .and_return()
)

for _ in range(2):
    it = iter(pipeline)

    assert next(it) == [0, 0]
    assert next(it) == [0, 0]
    assert next(it) == [1, 2]
    assert next(it) == [3, 4]

    try:
        next(it)
    except StopIteration:
        print("StopIteration")

    pipeline.reset()
StopIteration
StopIteration

Map

The map method applies a function to each element in the pipeline. It supports parallel execution: the num_parallel_calls argument sets the number of parallel calls, and, as the first example below shows, the output order still matches the input order.
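The idea of an order-preserving parallel map can be illustrated with the standard library's `ThreadPoolExecutor.map`, which runs calls concurrently but returns results in input order. This is an analogy only, not how fairseq2 executes map internally.

```python
from concurrent.futures import ThreadPoolExecutor


def square(d: int) -> int:
    return d**2


seq = list(range(1, 10))

# Up to 4 calls run concurrently; results still come back in input order.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(square, seq))

print(results)  # [1, 4, 9, 16, 25, 36, 49, 64, 81]
```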

map also accepts a list of functions, which are applied in order, each receiving the output of the previous one. The examples below cover both simple and complex cases, and show how to apply a function to only part of each element using the selector argument.

The selector argument specifies which parts of an element the function is applied to. A selector is a path made of dict keys separated by dots (foo1.foo2), list indices in brackets ([0], foo[0][1]), and [*] to select every item of a list; several paths can be separated by commas. When no selector is given, the function is applied to the whole element.
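The selector syntax used in the examples below can be modeled with a small parser. The following is a hypothetical pure-Python sketch of those rules (`parse_path`, `apply_at`, and `map_with_selector` are illustrative names, not fairseq2 APIs); fairseq2's own parser may accept or reject some inputs differently. Like the fairseq2 examples, it updates the selected values inside the element's containers in place.

```python
import re
from typing import Any, Callable, List, Optional, Sequence, Union

# A path step is a dict key (str), a list index (int), or None for "[*]".
Step = Optional[Union[str, int]]

_STEP = re.compile(r"([^.\[\]]+)|\[(\d+|\*)\]")


def parse_path(path: str) -> List[Step]:
    """Parse e.g. 'foo2[2].foo4' into ['foo2', 2, 'foo4']."""
    steps: List[Step] = []
    for part in path.strip().split("."):
        pos = 0
        while pos < len(part):
            m = _STEP.match(part, pos)
            if m is None:
                raise ValueError(f"invalid selector path: {path!r}")
            key, index = m.groups()
            if key is not None:
                steps.append(key)
            elif index == "*":
                steps.append(None)
            else:
                steps.append(int(index))
            pos = m.end()
    return steps


def apply_at(obj: Any, steps: Sequence[Step], fn: Callable[[Any], Any]) -> Any:
    """Apply fn to the value(s) reached by following steps."""
    if not steps:
        return fn(obj)
    head, rest = steps[0], steps[1:]
    if head is None:  # "[*]": every item of a list
        for i in range(len(obj)):
            obj[i] = apply_at(obj[i], rest, fn)
    else:
        obj[head] = apply_at(obj[head], rest, fn)
    return obj


def map_with_selector(
    element: Any, fns: Sequence[Callable[[Any], Any]], selector: str
) -> Any:
    # Comma-separated paths are handled independently; functions run in order.
    for path in selector.split(","):
        steps = parse_path(path)
        for fn in fns:
            element = apply_at(element, steps, fn)
    return element


print(map_with_selector([1, 2, 3], [lambda d: d + 10, lambda d: d * 2], "[1]"))
# -> [1, 24, 3]
```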

[33]:
# simple example
def fn(d: int) -> int:
    return d**2

seq = list(range(1, 10))

pipeline = read_sequence(seq).map(fn, num_parallel_calls=4).and_return()  # fmt: skip

for _ in range(2):
    assert list(pipeline) == [i**2 for i in seq]

    pipeline.reset()
[34]:
# list of functions

from fairseq2.data.text import StrToIntConverter

fn1 = StrToIntConverter()

def fn2(d: int) -> int:
    return d**2

pipeline = read_sequence(["1", "2", "3", "4"]).map([fn1, fn2]).and_return()

for _ in range(2):
    assert list(pipeline) == [1, 4, 9, 16]

    pipeline.reset()
[39]:
# a bit more complex example with a dataclass
from dataclasses import dataclass

@dataclass
class Foo:
    value: int

def fn(d: Foo) -> Foo:
    d.value += 2

    return d

pipeline = read_sequence([Foo(1), Foo(2)]).map(fn).and_return()

# fn mutates each Foo in place, so every pass over the data adds another 2.
for i in range(1, 3):
    it = iter(pipeline)

    assert next(it) == Foo(1 + (i * 2))
    assert next(it) == Foo(2 + (i * 2))

    pipeline.reset()
[40]:
# use selector to apply the function only to the selected elements
def fn1(d: int) -> int:
    return d + 10

def fn2(d: int) -> int:
    return d * 2

seq = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

pipeline = read_sequence(seq).map([fn1, fn2], selector="[1]").and_return()

for _ in range(2):
    it = iter(pipeline)

    assert next(it) == [1, 24, 3]
    assert next(it) == [4, 30, 6]
    assert next(it) == [7, 36, 9]

    pipeline.reset()
[42]:
# complex selector
# other valid selector strings:
# [
#     "[0]",
#     "[0][1]",
#     "foo",
#     "  foo ",
#     "foo1.foo2",
#     "foo[0]",
#     "foo[0][1]",
#     "foo[*]",
#     "foo[*][2]",
#     "foo[1][*]",
#     "foo1.foo2[0]",
#     "foo1,foo2",
#     "foo1[0],foo2[0]",
#     " foo1[0]  , foo2[1],foo3,  foo[*][3]",
# ]

import copy

def fn(d: int) -> int:
    return d + 10

d1 = {
    "foo1": 1,
    "foo2": [2, 3, {"foo4": 4}],
    "foo3": [5],
    "foo5": {"foo6": {"foo7": 1}},
}
d2 = {
    "foo1": 6,
    "foo2": [7, 8, {"foo4": 9}],
    "foo3": [0],
    "foo5": {"foo6": {"foo7": 2}},
}

e1 = copy.deepcopy(d1)
e2 = copy.deepcopy(d2)

e1["foo1"] = 11
e2["foo1"] = 16
e1["foo2"][2]["foo4"] = 14  # type: ignore[index]
e2["foo2"][2]["foo4"] = 19  # type: ignore[index]
e1["foo3"] = [15]
e2["foo3"] = [10]
e1["foo5"]["foo6"]["foo7"] = 11  # type: ignore[index]
e2["foo5"]["foo6"]["foo7"] = 12  # type: ignore[index]

selector = "foo2[2].foo4,foo3[0], foo1,foo5.foo6.foo7"

pipeline = read_sequence([d1, d2]).map(fn, selector=selector).and_return()

for _ in range(2):
    it = iter(pipeline)

    assert next(it) == e1
    assert next(it) == e2

    pipeline.reset()

Stateful Operations

The DataPipeline.state_dict method returns the current state of the pipeline (in the example below, its position in the data) as a state dict. The DataPipeline.load_state_dict method restores the pipeline from such a state dict.

This lets you checkpoint a pipeline and later resume reading from the same position, for example after an interruption.
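Conceptually, a state dict only needs enough information to resume reading at the same position. For a plain concatenation of in-memory lists, that is a single integer. The hypothetical class below (not part of fairseq2, whose state dicts are opaque here) illustrates the same save, read ahead, and roll back cycle that the example below performs.

```python
from typing import Any, Dict, Iterator, List, Sequence


class ConcatSourceSketch:
    """Hypothetical stateful source; its whole state is the read position."""

    def __init__(self, *sources: Sequence[Any]) -> None:
        self._items: List[Any] = [x for src in sources for x in src]
        self._pos = 0

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        if self._pos >= len(self._items):
            raise StopIteration
        item = self._items[self._pos]
        self._pos += 1
        return item

    def reset(self) -> None:
        self._pos = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"position": self._pos}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self._pos = state["position"]


source = ConcatSourceSketch([1, 2, 3, 4], [5, 6, 7, 8])
for _ in range(6):
    next(source)  # consumes 1..6
checkpoint = source.state_dict()
print(next(source))  # 7
source.load_state_dict(checkpoint)
print(next(source))  # 7 again, after rolling back
```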

[36]:
# this example explains how to restore the pipeline from a state dict
pipeline1 = read_sequence([1, 2, 3, 4]).and_return()
pipeline2 = read_sequence([5, 6, 7, 8]).and_return()

pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()

d = None

it = iter(pipeline)

# Read six elements: 1-4 from pipeline1, then 5 and 6 from pipeline2.
for _ in range(6):
    d = next(it)

assert d == 6

state_dict = pipeline.state_dict()

# Read one more example before we roll back.
d = next(it)

assert d == 7

# Roll back to the saved state, i.e. right after 6.
pipeline.load_state_dict(state_dict)

# Read the remaining two elements (7 and 8) to reach the end of data (EOD).
for _ in range(2):
    d = next(it)

assert d == 8

state_dict = pipeline.state_dict()

pipeline.reset()

# Restoring the saved EOD state should leave nothing to read.
pipeline.load_state_dict(state_dict)

try:
    # this should raise StopIteration
    next(iter(pipeline))
except StopIteration:
    print("StopIteration")
StopIteration