Data formats

This example benchmarks the speed of loading data in different formats.

See Case Studies / Data Format for details of how the data format and the loading function affect the performance of the training pipeline.

Source
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""This example benchmarks the speed of loading data in different formats.

See `Case Studies / Data Format <../case_studies/data_format.html>`_ for
details of how the data format and the loading function affect
the performance of the training pipeline.
"""

__all__ = [
    "main",
    "get_mock_data",
    "get_pipeline",
    "load_npy",
    "load_npy_spdl",
    "load_torch",
    "run_pipeline",
    "DataSource",
]

import time
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from io import BytesIO
from typing import Generic, TypeVar

import numpy as np
import spdl.io
import torch
from numpy.typing import NDArray
from spdl.pipeline import Pipeline, PipelineBuilder

# pyre-strict

T = TypeVar("T")


def get_pipeline(
    src: Iterable[T],
    load_fn: Callable[[T], list[NDArray]],
    num_workers: int,
    mode: str,
) -> Pipeline:
    """Build a pipeline that iterates the source, applying the load function with the given parallelism.

    Args:
        src: The data source.
        load_fn: The function that loads NumPy NDArrays from the source byte string.
        num_workers: The number of worker threads or processes.
        mode: The mode of parallelism. The valid values are ``"mt"`` (multi-threading)
            and ``"mp"`` (multi-processing).

    Returns:
        The resulting pipeline.
    """
    match mode:
        case "mt":
            executor = ThreadPoolExecutor(num_workers)
        case "mp":
            executor = ProcessPoolExecutor(num_workers)
        case _:
            raise ValueError(f'The `mode` must be either "mt" or "mp". Found: {mode}')

    return (
        PipelineBuilder()
        .add_source(src)
        .pipe(load_fn, concurrency=num_workers, executor=executor)
        .add_sink(buffer_size=1)
        .build(num_threads=1)
    )


def load_npy(items: list[bytes]) -> list[NDArray]:
    """Load arrays from serialized NPY binary strings using :py:func:`numpy.load`."""
    return [np.load(BytesIO(item), allow_pickle=False) for item in items]


def load_npy_spdl(items: list[bytes]) -> list[NDArray]:
    """Load arrays from serialized NPY binary strings using :py:func:`spdl.io.load_npy`."""
    return [spdl.io.load_npy(item) for item in items]


def load_npz(item: bytes) -> list[NDArray]:
    """Load arrays from a serialized NPZ binary string using :py:func:`numpy.load`."""
    data = np.load(BytesIO(item))
    return list(data.values())


def load_torch(item: bytes) -> list[NDArray]:
    """Load arrays from a serialized PyTorch state dict."""
    return list(torch.load(BytesIO(item)).values())


def _get_load_fn(
    data_format: str, provider: str
) -> Callable[[list[bytes]], list[NDArray]] | Callable[[bytes], list[NDArray]]:
    match data_format:
        case "torch":
            return load_torch
        case "npy":
            if provider == "spdl":
                return load_npy_spdl
            return load_npy
        case "npz":
            return load_npz
        case _:
            raise ValueError(f"Unexpected data format: {data_format}")


class DataSource(Generic[T]):
    """Yield the same data a given number of times.

    Args:
        data: Data to be yielded.
        repeat: The number of yields.
    """

    def __init__(self, data: T, repeat: int) -> None:
        self.data = data
        self.repeat = repeat

    def __iter__(self) -> Iterator[T]:
        for _ in range(self.repeat):
            yield self.data


def run_pipeline(pipeline: Pipeline[...]) -> tuple[int, float]:
    """Run the pipeline and measure the time."""
    t0 = time.monotonic()
    with pipeline.auto_stop():
        num_items = 0
        for _ in pipeline:
            num_items += 1
    elapsed = time.monotonic() - t0
    return num_items, elapsed


def _dump_np(arr: NDArray | dict[str, NDArray]) -> bytes:
    with BytesIO() as buf:
        if isinstance(arr, dict):
            np.savez(buf, **arr, allow_pickle=False)
        else:
            np.save(buf, arr, allow_pickle=False)
        buf.seek(0)
        return buf.read()


def _dump_torch(arr: dict[str, NDArray]) -> bytes:
    with BytesIO() as buf:
        torch.save({k: torch.from_numpy(v) for k, v in arr.items()}, buf)
        buf.seek(0)
        return buf.read()


def get_mock_data(format: str) -> tuple[bytes, bytes] | bytes:
    """Generate a single sample in the given format.

    The mock data resembles an RGB image and its segmentation labels.

    Args:
        format: One of ``"npz"``, ``"npy"`` or ``"torch"``.

    Returns:
        Serialized mock arrays. If ``"npy"``, the arrays are serialized
        separately. Otherwise, they are bundled together.
    """
    img = np.random.randint(256, size=(3, 640, 480), dtype=np.uint8)
    lbl = np.random.randint(256, size=(640, 480), dtype=np.uint8)

    match format:
        case "npz":
            return _dump_np({"img": img, "lbl": lbl})
        case "npy":
            return _dump_np(img), _dump_np(lbl)
        case "torch":
            return _dump_torch({"img": img, "lbl": lbl})
        case _:
            raise ValueError(f"Unexpected `format`: {format}")


def main() -> None:
    """The entrypoint from CLI."""
    configs = [
        ("torch", "torch"),
        ("npy", "npy"),
        ("npy", "spdl"),
        ("npz", "npz"),
    ]
    for data_format, io_func in configs:
        src = DataSource(get_mock_data(data_format), repeat=1000)
        load_fn = _get_load_fn(data_format, io_func)
        for mode in ["mp", "mt"]:
            for num_workers in [1, 2, 4, 8, 16, 32]:
                pipeline = get_pipeline(
                    src,  # pyre-ignore: [6]
                    load_fn,
                    num_workers,
                    mode,
                )
                num_items, elapsed = run_pipeline(pipeline)
                qps = num_items / elapsed
                print(f"{data_format},{io_func},{mode},{num_workers},{qps}")


if __name__ == "__main__":
    main()
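Running the script prints one CSV row per configuration, in the form data_format,io_func,mode,num_workers,qps. A minimal sketch for summarizing such output, assuming it was redirected to a file named results.csv (the filename is hypothetical):

    import csv

    # Each row: data_format, io_func, mode, num_workers, qps
    with open("results.csv") as f:
        rows = [
            (fmt, io_func, mode, int(n), float(qps))
            for fmt, io_func, mode, n, qps in csv.reader(f)
        ]

    # Report the fastest configuration observed per data format.
    best: dict[str, tuple] = {}
    for row in rows:
        fmt, qps = row[0], row[4]
        if fmt not in best or qps > best[fmt][4]:
            best[fmt] = row

    for fmt, (_, io_func, mode, n, qps) in best.items():
        print(f"{fmt}: {qps:.1f} QPS ({io_func}, {mode}, {n} workers)")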

Functions

main() → None

The entrypoint from CLI.

get_mock_data(format: str) → tuple[bytes, bytes] | bytes

Generate a single sample in the given format.

The mock data resembles an RGB image and its segmentation labels.

Parameters:

format – One of "npz", "npy" or "torch".

Returns:

Serialized mock arrays. If "npy", the arrays are serialized separately; otherwise they are bundled together.
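For instance, one can compare the serialized size of the same sample across formats; a quick sketch using only the function above (not part of the benchmark):

    for fmt in ("npz", "npy", "torch"):
        sample = get_mock_data(fmt)
        if fmt == "npy":
            # "npy" serializes the image and the label separately.
            img_bytes, lbl_bytes = sample
            print(f"npy: img={len(img_bytes)} bytes, lbl={len(lbl_bytes)} bytes")
        else:
            print(f"{fmt}: {len(sample)} bytes")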

get_pipeline(src: Iterable[T], load_fn: Callable[[T], list[NDArray]], num_workers: int, mode: str) → Pipeline

Build a pipeline that iterates the source, applying the load function with the given parallelism.

Parameters:
  • src – The data source.

  • load_fn – The function that loads NumPy NDArrays from the source byte string.

  • num_workers – The number of worker threads or processes.

  • mode – The mode of parallelism. The valid values are "mt" (multi-threading) and "mp" (multi-processing).

Returns:

The resulting pipeline.
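For example, a sketch that builds a multi-threaded pipeline over NPY samples using the helpers defined in this example (with mode="mp", the load function must be picklable, which is why the loaders are defined at module level):

    src = DataSource(get_mock_data("npy"), repeat=100)
    pipeline = get_pipeline(src, load_npy, num_workers=4, mode="mt")
    with pipeline.auto_stop():
        for arrays in pipeline:
            ...  # each item is the list of NumPy arrays returned by load_npy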

load_npy(items: list[bytes]) → list[NDArray]

Load arrays from serialized NPY binary strings using numpy.load().

load_npy_spdl(items: list[bytes]) → list[NDArray]

Load arrays from serialized NPY binary strings using spdl.io.load_npy().

load_torch(item: bytes) → list[NDArray]

Load arrays from a serialized PyTorch state dict.
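A round-trip sanity check for the loaders; a sketch pairing them with get_mock_data (both NPY loaders should decode the same bytes into identical arrays):

    import numpy as np

    img_bytes, lbl_bytes = get_mock_data("npy")
    (via_numpy,) = load_npy([img_bytes])
    (via_spdl,) = load_npy_spdl([img_bytes])
    np.testing.assert_array_equal(via_numpy, via_spdl)

    # load_torch returns the tensors stored in the saved state dict.
    img, lbl = load_torch(get_mock_data("torch"))
    assert tuple(img.shape) == (3, 640, 480)
    assert tuple(lbl.shape) == (640, 480)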

run_pipeline(pipeline: Pipeline[...]) → tuple[int, float]

Run the pipeline and measure the time.
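Combined with get_pipeline, this produces the throughput figure that main() prints; a minimal sketch:

    src = DataSource(get_mock_data("torch"), repeat=1000)
    pipeline = get_pipeline(src, load_torch, num_workers=8, mode="mt")
    num_items, elapsed = run_pipeline(pipeline)
    print(f"{num_items} items in {elapsed:.2f}s -> {num_items / elapsed:.1f} QPS")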

Classes

class DataSource(data: T, repeat: int)

Yield the same data a given number of times.

Parameters:
  • data – Data to be yielded.

  • repeat – The number of yields.

__iter__() → Iterator[T]
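For example, yielding the same payload three times:

    src = DataSource(b"payload", repeat=3)
    assert list(src) == [b"payload"] * 3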