{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ✎ Data Pipeline\n", "\n", "## Overview\n", "\n", "This notebook presents examples of how to use `DataPipeline` to create a data pipeline." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from fairseq2.data import DataPipeline, read_sequence" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Combine multiple pipelines" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Round Robin\n", "\n", "The `DataPipeline.round_robin` method is used to combine multiple pipelines and return a new pipeline,\n", "which will yield elements from each of the input pipelines in a round-robin fashion." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n", "5\n", "0\n", "2\n", "6\n", "2\n", "3\n", "7\n", "4\n", "4\n", "8\n", "6\n", "1\n", "5\n", "0\n", "2\n", "6\n", "2\n", "3\n", "7\n", "4\n", "4\n", "8\n", "6\n" ] } ], "source": [ "# finite pipeline\n", "pipeline1 = read_sequence([1, 2, 3, 4]).and_return()\n", "pipeline2 = read_sequence([5, 6, 7, 8]).and_return()\n", "pipeline3 = read_sequence([0, 2, 4, 6]).and_return()\n", "\n", "pipeline = DataPipeline.round_robin(\n", " [pipeline1, pipeline2, pipeline3]\n", ").and_return()\n", "\n", "for i in pipeline:\n", " print(i)\n", "\n", "pipeline.reset()\n", "\n", "for i in pipeline:\n", " print(i)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# pseudo-infinite pipeline\n", "pipeline1 = read_sequence([1, 2, 3, 4]).and_return()\n", "pipeline2 = DataPipeline.constant(0).and_return()\n", "pipeline3 = read_sequence([0, 2, 4, 6]).and_return()\n", "\n", "pipeline = DataPipeline.round_robin(\n", " [pipeline1, pipeline2, pipeline3]\n", ").and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [1, 0, 0, 2, 0, 2, 3, 0, 4, 4, 0, 6]\n", "\n", " 
pipeline.reset()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# infinite pipeline\n", "pipeline1 = read_sequence([1, 2, 3, 4]).and_return()\n", "pipeline2 = read_sequence([0]).repeat().and_return()\n", "pipeline3 = read_sequence([0, 2, 4, 6]).and_return()\n", "\n", "pipeline = DataPipeline.round_robin(\n", " [pipeline1, pipeline2, pipeline3]\n", ").and_return()\n", "\n", "for _ in range(2):\n", " it = iter(pipeline)\n", "\n", " assert [next(it) for i in range(15)] == [1, 0, 0, 2, 0, 2, 3, 0, 4, 4, 0, 6, 1, 0, 0]\n", "\n", " pipeline.reset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Zip\n", "\n", "The `DataPipeline.zip` method combines multiple pipelines into a new pipeline\n", "that yields, at each step, one element from each input pipeline, grouped together like Python's built-in `zip`." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# finite pipeline\n", "pipeline1 = read_sequence([1, 2, 3, 4]).and_return()\n", "pipeline2 = read_sequence([5, 6, 7, 8]).and_return()\n", "pipeline3 = read_sequence([0, 2, 4, 6]).and_return()\n", "\n", "pipeline = DataPipeline.zip([pipeline1, pipeline2, pipeline3]).and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [[1, 5, 0], [2, 6, 2], [3, 7, 4], [4, 8, 6]]\n", "\n", " pipeline.reset()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# pseudo-infinite pipeline\n", "pipeline1 = read_sequence([1, 2, 3, 4]).and_return()\n", "pipeline2 = DataPipeline.constant(0).and_return()\n", "pipeline3 = read_sequence([5, 6, 7, 8]).and_return()\n", "\n", "pipeline = DataPipeline.zip([pipeline1, pipeline2, pipeline3]).and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [[1, 0, 5], [2, 0, 6], [3, 0, 7], [4, 0, 8]]\n", "\n", " pipeline.reset()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# infinite pipeline\n", 
"pipeline1 = read_sequence([0]).repeat().and_return()\n", "pipeline2 = read_sequence([1]).repeat().and_return()\n", "pipeline3 = read_sequence([2]).repeat().and_return()\n", "\n", "pipeline = DataPipeline.zip([pipeline1, pipeline2, pipeline3]).and_return()\n", "\n", "for _ in range(2):\n", " it = iter(pipeline)\n", "\n", " assert [next(it) for i in range(2)] == [[0, 1, 2], [0, 1, 2]]\n", "\n", " pipeline.reset()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# flatten and dict input\n", "pipeline1 = read_sequence([{\"foo1\": 1}, {\"foo1\": 2}, {\"foo1\": 3}]).and_return()\n", "pipeline2 = read_sequence([{\"foo2\": 4, \"foo3\": 5}, {\"foo2\": 6, \"foo3\": 7}, {\"foo2\": 8, \"foo3\": 9}]).and_return() # fmt: skip\n", "pipeline3 = read_sequence([{\"foo4\": 2}, {\"foo4\": 3}, {\"foo4\": 4}]).and_return()\n", "\n", "pipeline = DataPipeline.zip(\n", " [pipeline1, pipeline2, pipeline3], flatten=True\n", ").and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [\n", " {\"foo1\": 1, \"foo2\": 4, \"foo3\": 5, \"foo4\": 2},\n", " {\"foo1\": 2, \"foo2\": 6, \"foo3\": 7, \"foo4\": 3},\n", " {\"foo1\": 3, \"foo2\": 8, \"foo3\": 9, \"foo4\": 4},\n", " ]\n", "\n", " pipeline.reset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sample\n", "\n", "The `DataPipeline.sample` method samples elements from multiple pipelines according to per-pipeline weights.\n", "The `weights` argument is a list of floats giving the relative weight with which each pipeline is sampled; the values do not have to sum to 1.\n", "The `seed` argument sets the random seed for the sampling process." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# finite pipeline\n", "pipeline1 = read_sequence([1, 2, 3, 4]).and_return()\n", "pipeline2 = read_sequence([5, 6, 7]).and_return()\n", "\n", "pipeline = DataPipeline.sample(\n", " [pipeline1, pipeline2], weights=[1.2, 0.8], seed=1234\n", ").and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [1, 2, 3, 4, 1, 5, 2, 3, 6, 4, 7]\n", "\n", " pipeline.reset(reset_rng=True)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# pseudo-infinite pipeline\n", "pipeline1 = read_sequence([1, 2, 3, 4]).and_return()\n", "pipeline2 = DataPipeline.count(5).and_return()\n", "\n", "pipeline = DataPipeline.sample(\n", " [pipeline1, pipeline2], weights=[0.4, 0.6], seed=1234\n", ").and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [1, 5, 2, 3, 4]\n", "\n", " pipeline.reset(reset_rng=True)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# infinite pipeline\n", "pipeline1 = read_sequence([1, 2, 3, 4]).and_return()\n", "pipeline2 = read_sequence([5, 6, 7, 8]).repeat().and_return()\n", "\n", "pipeline = DataPipeline.sample(\n", " [pipeline1, pipeline2], weights=[0.4, 0.6], seed=1234\n", ").and_return()\n", "\n", "it = iter(pipeline)\n", "\n", "assert [next(it) for i in range(10)] == [1, 5, 2, 3, 4, 6, 1, 7, 8, 2]" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# allow_repeats=False\n", "pipeline1 = read_sequence([1, 2, 3, 4]).and_return()\n", "pipeline2 = read_sequence([5, 6]).and_return()\n", "pipeline3 = read_sequence([7, 8, 9, 10, 11, 12]).and_return()\n", "\n", "pipeline = DataPipeline.sample(\n", " [pipeline1, pipeline2, pipeline3],\n", " weights=[0.3, 0.6, 0.1],\n", " allow_repeats=False,\n", " seed=1234,\n", ").and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [1, 5, 2, 6, 3, 4, 7, 8, 
9, 10, 11, 12]\n", "\n", " pipeline.reset(reset_rng=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Concatenate\n", "\n", "The `DataPipeline.concat` method is used to concatenate multiple pipelines into a single pipeline." ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "pipeline1 = read_sequence([1, 2, 3, 4]).and_return()\n", "pipeline2 = read_sequence([5, 6, 7, 8]).and_return()\n", "\n", "pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [1, 2, 3, 4, 5, 6, 7, 8]\n", "\n", " pipeline.reset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Shuffle\n", "\n", "The `DataPipeline.shuffle` method is used to shuffle the elements in the pipeline.\n", "The `buffer_size` argument is used to set the shuffle buffer size.\n", "If the buffer size is greater than the number of elements in the pipeline, or the buffer size is set to 0, the pipeline will be shuffled completely.\n", "The `seed` argument is used to set the random seed for the shuffling process." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "seq = list(range(1, 10))\n", "\n", "# Shuffle buffer 100 > 9 elements -> full shuffle\n", "pipeline = read_sequence(seq).shuffle(100, seed=1234).and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [8, 9, 3, 7, 5, 4, 2, 6, 1]\n", "\n", " pipeline.reset(reset_rng=True)\n", "\n", "# exhaust the pipeline to start a new shuffle\n", "_ = list(pipeline)\n", "\n", "# reset the pipeline without resetting the seed\n", "pipeline.reset(reset_rng=False)\n", "\n", "# We haven't reset the seed. 
The list should be different this time.\n", "assert list(pipeline) != [8, 9, 3, 7, 5, 4, 2, 6, 1]" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "# Shuffle the whole list by setting shuffle buffer to 0\n", "pipeline = read_sequence(seq).shuffle(0, seed=1234).and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [8, 9, 3, 7, 5, 4, 2, 6, 1]\n", "\n", " pipeline.reset(reset_rng=True)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "# Shuffle 4 elements per buffer\n", "pipeline = read_sequence(seq).shuffle(4, seed=1234).and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [2, 1, 3, 4, 5, 7, 8, 6, 9]\n", "\n", " pipeline.reset(reset_rng=True)\n", "\n", "# Shuffle 1 element per buffer -> no shuffle\n", "pipeline = read_sequence(seq).shuffle(1, seed=1234).and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [1, 2, 3, 4, 5, 6, 7, 8, 9]\n", "\n", " pipeline.reset(reset_rng=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bucket\n", "\n", "The `DataPipeline.bucket` method is used to bucket the elements in the pipeline.\n", "This method is useful when you want to process data in batches.\n", "Dynamic bucketing is also supported through `dynamic_bucket`, where a bucket is emitted once the accumulated cost of its elements reaches a threshold.\n", "- The `bucket_size` argument sets the bucket size (number of elements in each bucket).\n", "- The `cost_fn` argument sets the cost function for dynamic bucketing. The cost of each element is accumulated over the bucket.\n", "- The `min_num_examples` argument sets the minimum number of examples in each bucket.\n", "- The `max_num_examples` argument sets the maximum number of examples in each bucket.\n", "- The `drop_remainder` argument sets whether to drop the remaining examples in the last bucket." 
] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StopIteration\n", "StopIteration\n" ] } ], "source": [ "# simple bucketing\n", "seq = list(range(100))\n", "\n", "bucket_size = 4\n", "\n", "pipeline = read_sequence(seq).bucket(bucket_size).and_return()\n", "\n", "for _ in range(2):\n", " it = iter(pipeline)\n", "\n", " for i in range(25):\n", "     d = next(it)\n", "\n", "     offset = i * bucket_size\n", "\n", "     assert d == [offset + j for j in range(bucket_size)]\n", "\n", " try:\n", "     next(it)\n", " except StopIteration:\n", "     print(\"StopIteration\")\n", "\n", " pipeline.reset()" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StopIteration\n", "StopIteration\n" ] } ], "source": [ "# dynamic bucketing (square cost function)\n", "seq = list(range(1, 7))\n", "\n", "threshold = 14\n", "cost_fn = lambda x: x**2 # each element costs its square; costs are summed per bucket\n", "\n", "pipeline = read_sequence(seq).dynamic_bucket(threshold, cost_fn).and_return()\n", "\n", "for _ in range(2):\n", " it = iter(pipeline)\n", "\n", " assert next(it) == [1, 2, 3]\n", " assert next(it) == [4]\n", " assert next(it) == [5]\n", " assert next(it) == [6]\n", "\n", " try:\n", "     next(it)\n", " except StopIteration:\n", "     print(\"StopIteration\")\n", "\n", " pipeline.reset()" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StopIteration\n", "StopIteration\n" ] } ], "source": [ "# dynamic bucketing (length cost function)\n", "seq = [[1, 2], [3, 4, 5], [6], [7], [8, 9, 10], [11, 12, 13, 14, 15, 16]]\n", "\n", "threshold = 5\n", "cost_fn = lambda x: len(x)\n", "\n", "pipeline = read_sequence(seq).dynamic_bucket(threshold, cost_fn).and_return()\n", "\n", "for _ in range(2):\n", " it = iter(pipeline)\n", "\n", " assert next(it) == [[1, 2], [3, 4, 
5]]\n", " assert next(it) == [[6], [7], [8, 9, 10]]\n", " assert next(it) == [[11, 12, 13, 14, 15, 16]]\n", "\n", " try:\n", "     next(it)\n", " except StopIteration:\n", "     print(\"StopIteration\")\n", "\n", " pipeline.reset()" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StopIteration\n", "StopIteration\n" ] } ], "source": [ "# dynamic bucketing (more constraints)\n", "seq = [0, 0, 0, 0, 1, 2, 3, 4, 5]\n", "\n", "threshold = 3\n", "cost_fn = lambda x: x\n", "\n", "pipeline = (\n", " read_sequence(seq)\n", " .dynamic_bucket(threshold, cost_fn, min_num_examples=2, max_num_examples=2, drop_remainder=True)\n", " .and_return()\n", ")\n", "\n", "for _ in range(2):\n", " it = iter(pipeline)\n", "\n", " assert next(it) == [0, 0]\n", " assert next(it) == [0, 0]\n", " assert next(it) == [1, 2]\n", " assert next(it) == [3, 4]\n", "\n", " try:\n", "     next(it)\n", " except StopIteration:\n", "     print(\"StopIteration\")\n", "\n", " pipeline.reset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Map\n", "\n", "The `DataPipeline.map` method is used to apply a function to each element in the pipeline.\n", "It supports parallel execution of the function.\n", "You can set the number of parallel calls with the `num_parallel_calls` argument.\n", "\n", "`map` allows you to apply one or more functions to all elements in the pipeline.\n", "We have examples below to showcase both simple and complex cases.\n", "We also showcase how to apply a function to a subset of elements in the pipeline using the `selector` argument.\n", "\n", "The `selector` argument is used to specify the elements to apply the function to.\n", "It supports the same syntax as the `selector` argument in the `Dataset.map` method.\n", "If you want to apply the function to all elements, you can set the `selector` to `\"*\"`.\n" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ 
"# simple example\n", "def fn(d: int) -> int:\n", " return d**2\n", "\n", "seq = list(range(1, 10))\n", "\n", "pipeline = read_sequence(seq).map(fn, num_parallel_calls=4).and_return() # fmt: skip\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [i**2 for i in seq]\n", "\n", " pipeline.reset()" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "# list of functions\n", "\n", "from fairseq2.data.text import StrToIntConverter\n", "\n", "fn1 = StrToIntConverter()\n", "\n", "def fn2(d: int) -> int:\n", " return d**2\n", "\n", "pipeline = read_sequence([\"1\", \"2\", \"3\", \"4\"]).map([fn1, fn2]).and_return()\n", "\n", "for _ in range(2):\n", " assert list(pipeline) == [1, 4, 9, 16]\n", "\n", " pipeline.reset()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "# a bit more complex example with a dataclass\n", "from dataclasses import dataclass\n", "\n", "@dataclass\n", "class Foo:\n", " value: int\n", "\n", "def fn(d: Foo) -> Foo:\n", " d.value += 2\n", "\n", " return d\n", "\n", "pipeline = read_sequence([Foo(1), Foo(2)]).map(fn).and_return()\n", "\n", "it = iter(pipeline)\n", "\n", "for i in range(1, 3):\n", " assert next(it) == Foo(1 + (i * 2))\n", " assert next(it) == Foo(2 + (i * 2))\n", "\n", " pipeline.reset()" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "# use selector to apply the function only to the selected elements\n", "def fn1(d: int) -> int:\n", " return d + 10\n", "\n", "def fn2(d: int) -> int:\n", " return d * 2\n", "\n", "seq = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]\n", "\n", "pipeline = read_sequence(seq).map([fn1, fn2], selector=\"[1]\").and_return()\n", "\n", "for _ in range(2):\n", " it = iter(pipeline)\n", "\n", " assert next(it) == [1, 24, 3]\n", " assert next(it) == [4, 30, 6]\n", " assert next(it) == [7, 36, 9]\n", "\n", " pipeline.reset()" ] }, { "cell_type": "code", "execution_count": 42, 
"metadata": {}, "outputs": [], "source": [ "# complex selector\n", "# more examples below:\n", "# [\n", "# \"[0]\",\n", "# \"[0][1]\",\n", "# \"foo\",\n", "# \" foo \",\n", "# \"foo1.foo2\",\n", "# \"foo[0]\",\n", "# \"foo[0][1]\",\n", "# \"foo[*]\",\n", "# \"foo[*][2]\",\n", "# \"foo[1][*]\",\n", "# \"foo1.foo2[0]\",\n", "# \"foo1,foo2\",\n", "# \"foo1[0],foo2[0]\",\n", "# \" foo1[0] , foo2[1],foo3, foo[*][3]\",\n", "# ]\n", "\n", "import copy\n", "\n", "def fn(d: int) -> int:\n", " return d + 10\n", "\n", "d1 = {\n", " \"foo1\": 1,\n", " \"foo2\": [2, 3, {\"foo4\": 4}],\n", " \"foo3\": [5],\n", " \"foo5\": {\"foo6\": {\"foo7\": 1}},\n", "}\n", "d2 = {\n", " \"foo1\": 6,\n", " \"foo2\": [7, 8, {\"foo4\": 9}],\n", " \"foo3\": [0],\n", " \"foo5\": {\"foo6\": {\"foo7\": 2}},\n", "}\n", "\n", "e1 = copy.deepcopy(d1)\n", "e2 = copy.deepcopy(d2)\n", "\n", "e1[\"foo1\"] = 11\n", "e2[\"foo1\"] = 16\n", "e1[\"foo2\"][2][\"foo4\"] = 14 # type: ignore[index]\n", "e2[\"foo2\"][2][\"foo4\"] = 19 # type: ignore[index]\n", "e1[\"foo3\"] = [15]\n", "e2[\"foo3\"] = [10]\n", "e1[\"foo5\"][\"foo6\"][\"foo7\"] = 11 # type: ignore[index]\n", "e2[\"foo5\"][\"foo6\"][\"foo7\"] = 12 # type: ignore[index]\n", "\n", "selector = \"foo2[2].foo4,foo3[0], foo1,foo5.foo6.foo7\"\n", "\n", "pipeline = read_sequence([d1, d2]).map(fn, selector=selector).and_return()\n", "\n", "for _ in range(2):\n", " it = iter(pipeline)\n", "\n", " assert next(it) == e1\n", " assert next(it) == e2\n", "\n", " pipeline.reset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stateful Operations\n", "\n", "The `DataPipeline.state_dict` method is used to save the state of the pipeline.\n", "The `DataPipeline.load_state_dict` method is used to restore the pipeline from a state dict.\n", "\n", "This is useful when you want to save the state of the pipeline and restore it later." 
] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StopIteration\n" ] } ], "source": [ "# this example shows how to save the pipeline state and restore it from a state dict\n", "pipeline1 = read_sequence([1, 2, 3, 4]).and_return()\n", "pipeline2 = read_sequence([5, 6, 7, 8]).and_return()\n", "\n", "pipeline = DataPipeline.concat([pipeline1, pipeline2]).and_return()\n", "\n", "d = None\n", "\n", "it = iter(pipeline)\n", "\n", "# Read the first six examples.\n", "for _ in range(6):\n", " d = next(it)\n", "\n", "assert d == 6\n", "\n", "state_dict = pipeline.state_dict()\n", "\n", "# Read one more example before we roll back.\n", "d = next(it)\n", "\n", "assert d == 7\n", "\n", "# Roll back to the saved position, just after the sixth example.\n", "pipeline.load_state_dict(state_dict)\n", "\n", "# Move to the end of data (EOD).\n", "for _ in range(2):\n", " d = next(it)\n", "\n", "assert d == 8\n", "\n", "state_dict = pipeline.state_dict()\n", "\n", "pipeline.reset()\n", "\n", "# Restore the saved state; the pipeline is expected to be at EOD.\n", "pipeline.load_state_dict(state_dict)\n", "\n", "try:\n", " # this should raise StopIteration\n", " next(iter(pipeline))\n", "except StopIteration:\n", " print(\"StopIteration\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 2 }