Data formats

This example benchmarks the speed of loading data in different formats.

See Case Studies / Data Format for details on how the data format and the loading function affect the performance of the training pipeline.

Source
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""This example benchmarks the speed of loading data in different formats.

See `Case Studies / Data Format <../case_studies/data_format.html>`_ for
details on how the data format and the loading function affect
the performance of the training pipeline.
"""

__all__ = [
    "main",
    "get_mock_data",
    "get_pipeline",
    "load_npy",
    "load_npy_spdl",
    "load_torch",
    "run_pipeline",
    "DataSource",
]

import time
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from io import BytesIO
from typing import Generic, TypeVar

import numpy as np
import spdl.io
import torch
from numpy.typing import NDArray
from spdl.pipeline import Pipeline, PipelineBuilder

# pyre-strict

T = TypeVar("T")


def get_pipeline(
    src: Iterable[T],
    load_fn: Callable[[T], list[NDArray]],
    num_workers: int,
    mode: str,
) -> Pipeline:
    """Build a pipeline that iterates the source with the load function using the given parallelism.

    Args:
        src: The data source.
        load_fn: The function that loads NumPy NDArrays from the source byte string.
        num_workers: The number of worker threads or processes.
        mode: The mode of parallelism. The valid values are ``"mt"`` (multi-threading)
            and ``"mp"`` (multi-processing).

    Returns:
        The resulting pipeline.
    """
    match mode:
        case "mt":
            executor = ThreadPoolExecutor(num_workers)
        case "mp":
            executor = ProcessPoolExecutor(num_workers)
        case _:
            raise ValueError(f'The `mode` must be either "mt" or "mp". Found: {mode}')

    return (
        PipelineBuilder()
        .add_source(src)
        .pipe(load_fn, concurrency=num_workers, executor=executor)
        .add_sink(buffer_size=1)
        .build(num_threads=1)
    )


def load_npy(items: list[bytes]) -> list[NDArray]:
    """Load arrays from serialized NPY binary strings using :py:func:`numpy.load`."""
    return [np.load(BytesIO(item), allow_pickle=False) for item in items]


def load_npy_spdl(items: list[bytes]) -> list[NDArray]:
    """Load arrays from serialized NPY binary strings using :py:func:`spdl.io.load_npy`."""
    return [spdl.io.load_npy(item) for item in items]


def load_npz(item: bytes) -> list[NDArray]:
    """Load arrays from a serialized NPZ binary string using :py:func:`numpy.load`."""
    data = np.load(BytesIO(item))
    return list(data.values())


def load_npz_spdl(item: bytes) -> list[NDArray]:
    """Load arrays from a serialized NPZ binary string using :py:func:`spdl.io.load_npz`."""
    data = spdl.io.load_npz(item)
    return list(data.values())


def load_torch(item: bytes) -> list[NDArray]:
    """Load arrays from a serialized PyTorch state dict."""
    return list(torch.load(BytesIO(item)).values())


def _get_load_fn(
    data_format: str, impl: str
) -> Callable[[list[bytes]], list[NDArray]] | Callable[[bytes], list[NDArray]]:
    match data_format:
        case "torch":
            return load_torch
        case "npy":
            if impl == "spdl":
                return load_npy_spdl
            return load_npy
        case "npz":
            if impl == "spdl":
                return load_npz_spdl
            return load_npz
        case _:
            raise ValueError(f"Unexpected data format: {data_format}")


class DataSource(Generic[T]):
    """Yield the same data a given number of times.

    Args:
        data: Data to be yielded.
        repeat: The number of times the data is yielded.
    """

    def __init__(self, data: T, repeat: int) -> None:
        self.data = data
        self.repeat = repeat

    def __iter__(self) -> Iterator[T]:
        for _ in range(self.repeat):
            yield self.data


def run_pipeline(pipeline: Pipeline[...]) -> tuple[int, float]:
    """Run the pipeline and measure the time."""
    t0 = time.monotonic()
    with pipeline.auto_stop():
        num_items = 0
        for _ in pipeline:
            num_items += 1
    elapsed = time.monotonic() - t0
    return num_items, elapsed


def _dump_np(arr: NDArray | dict[str, NDArray], compressed: bool = False) -> bytes:
    # Serialize a single array as NPY, or a dict of arrays as NPZ.
    with BytesIO() as buf:
        if isinstance(arr, dict):
            if compressed:
                np.savez_compressed(buf, allow_pickle=False, **arr)
            else:
                np.savez(buf, allow_pickle=False, **arr)
        else:
            np.save(buf, arr, allow_pickle=False)
        buf.seek(0)
        return buf.read()


def _dump_torch(arr: dict[str, NDArray]) -> bytes:
    # Serialize a dict of arrays as a PyTorch state dict.
    with BytesIO() as buf:
        torch.save({k: torch.from_numpy(v) for k, v in arr.items()}, buf)
        buf.seek(0)
        return buf.read()


def get_mock_data(format: str, compressed: bool = False) -> tuple[bytes, bytes] | bytes:
    """Generate a single sample in the given format.

    The mock data resembles an RGB image and its segmentation labels.

    Args:
        format: One of ``"npz"``, ``"npy"`` or ``"torch"``.
        compressed: If ``True``, the NPZ file is compressed.
            (i.e. :py:func:`numpy.savez_compressed` is used.)

    Returns:
        Serialized mock arrays. If ``"npy"``, the arrays are serialized
        separately. Otherwise they are bundled together.
    """
    img = np.random.randint(256, size=(3, 640, 480), dtype=np.uint8)
    lbl = np.random.randint(256, size=(640, 480), dtype=np.uint8)

    match format:
        case "npz":
            return _dump_np({"img": img, "lbl": lbl}, compressed=compressed)
        case "npy":
            return _dump_np(img), _dump_np(lbl)
        case "torch":
            return _dump_torch({"img": img, "lbl": lbl})
        case _:
            raise ValueError(f"Unexpected `format`: {format}")


def main() -> None:
    """The entrypoint from CLI."""
    configs = [
        ("torch", False, "torch"),
        ("npy", False, "np"),
        ("npy", False, "spdl"),
        ("npz", False, "np"),
        ("npz", True, "np"),
        ("npz", False, "spdl"),
        ("npz", True, "spdl"),
    ]
    for data_format, compressed, impl in configs:
        src = DataSource(get_mock_data(data_format, compressed), repeat=1000)
        load_fn = _get_load_fn(data_format, impl)
        for mode in ["mp", "mt"]:
            for num_workers in [1, 2, 4, 8, 16, 32]:
                pipeline = get_pipeline(
                    src,  # pyre-ignore: [6]
                    load_fn,
                    num_workers,
                    mode,
                )
                num_items, elapsed = run_pipeline(pipeline)
                qps = num_items / elapsed
                print(f"{data_format},{compressed},{impl},{mode},{num_workers},{qps}")


if __name__ == "__main__":
    main()

Functions

main() → None

The entrypoint from CLI.
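
main() prints one comma-separated line per configuration. Below is a small sketch of collecting that output into records, assuming stdout was redirected to a file named results.csv (a hypothetical name); the column names are taken from the print call in the source:

import csv

# Column order matches the print() call in main().
COLUMNS = ["data_format", "compressed", "impl", "mode", "num_workers", "qps"]

with open("results.csv") as f:
    rows = [dict(zip(COLUMNS, line)) for line in csv.reader(f)]
print(rows[0])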

get_mock_data(format: str, compressed: bool = False) → tuple[bytes, bytes] | bytes

Generate a single sample in the given format.

The mock data resembles an RGB image and its segmentation labels.

Parameters:
  • format – One of "npz", "npy" or "torch".

  • compressed – If True, the NPZ file is compressed (i.e. numpy.savez_compressed() is used).

Returns:

Serialized mock arrays. If "npy", the arrays are serialized separately. Otherwise they are bundled together.
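
A minimal usage sketch comparing payload sizes across formats; the variable names below are illustrative only:

img_npy, lbl_npy = get_mock_data("npy")         # two separate NPY blobs
npz_raw = get_mock_data("npz")                  # one bundled NPZ blob
npz_gz = get_mock_data("npz", compressed=True)  # bundled and compressed

print(len(img_npy) + len(lbl_npy), len(npz_raw), len(npz_gz))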

get_pipeline(src: Iterable[T], load_fn: Callable[[T], list[NDArray]], num_workers: int, mode: str) → Pipeline

Build a pipeline that iterates the source with the load function using the given parallelism.

Parameters:
  • src – The data source.

  • load_fn – The function that loads NumPy NDArrays from the source byte string.

  • num_workers – The number of worker threads or processes.

  • mode – The mode of parallelism. The valid values are "mt" (multi-threading) and "mp" (multi-processing).

Returns:

The resulting pipeline.
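
A minimal construction sketch, assuming the functions defined in this example are in scope:

src = DataSource(get_mock_data("npy"), repeat=100)
pipeline = get_pipeline(src, load_npy, num_workers=4, mode="mt")

Passing mode="mp" swaps the ThreadPoolExecutor for a ProcessPoolExecutor; any other value raises ValueError.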

load_npy(items: list[bytes]) → list[NDArray]

Load arrays from serialized NPY binary strings using numpy.load().

load_npy_spdl(items: list[bytes]) → list[NDArray]

Load arrays from serialized NPY binary strings using spdl.io.load_npy().
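
A round-trip sketch showing that both implementations decode the same NPY payload; the array and buffer below are illustrative:

from io import BytesIO

import numpy as np

arr = np.arange(12, dtype=np.float32).reshape(3, 4)
with BytesIO() as buf:
    np.save(buf, arr, allow_pickle=False)
    payload = buf.getvalue()

assert np.array_equal(load_npy([payload])[0], load_npy_spdl([payload])[0])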

load_torch(item: bytes) → list[NDArray]

Load arrays from a serialized PyTorch state dict.

run_pipeline(pipeline: Pipeline[...]) → tuple[int, float]

Run the pipeline and measure the time.
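
A sketch of a single measurement, mirroring one iteration of the loops in main() (load_npz is defined in the source above, though not exported):

src = DataSource(get_mock_data("npz"), repeat=1000)
pipeline = get_pipeline(src, load_npz, num_workers=8, mode="mt")
num_items, elapsed = run_pipeline(pipeline)
print(f"{num_items} items in {elapsed:.2f}s ({num_items / elapsed:.1f} QPS)")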

Classes

class DataSource(data: T, repeat: int)

Yield the same data a given number of times.

Parameters:
  • data – Data to be yielded.

  • repeat – The number of times the data is yielded.

__iter__() → Iterator[T]
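
A tiny illustration of the iteration semantics: the same object is yielded repeat times, so the source itself adds no I/O cost to the benchmark.

src = DataSource(data=b"payload", repeat=3)
assert list(src) == [b"payload", b"payload", b"payload"]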