Data formats¶
This example benchmarks the speed of loading data in different formats.
See Case Studies / Data Format for details on how the data format and the loading function affect the performance of the training pipeline.
Source¶
Source
Click here to see the source.
1#!/usr/bin/env python3
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8"""This example benchmarks the speed of loading data in different formats.
9
10See `Case Studies / Data Format <../case_studies/data_format.html>`_ for
11details on how the data format and the loading function affect
12the performance of the training pipeline.
13"""
14
# Public API of this example. Every public loader is listed so that the NPZ
# loaders are exported alongside the NPY and PyTorch ones.
__all__ = [
    "main",
    "get_mock_data",
    "get_pipeline",
    "load_npy",
    "load_npy_spdl",
    "load_npz",
    "load_npz_spdl",
    "load_torch",
    "run_pipeline",
    "DataSource",
]
25
26import time
27from collections.abc import Callable, Iterable, Iterator
28from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
29from io import BytesIO
30from typing import Generic, TypeVar
31
32import numpy as np
33import spdl.io
34import torch
35from numpy.typing import NDArray
36from spdl.pipeline import Pipeline, PipelineBuilder
37
38# pyre-strict
39
40T = TypeVar("T")
41
42
def get_pipeline(
    src: Iterable[T],
    load_fn: Callable[[T], list[NDArray]],
    num_workers: int,
    mode: str,
) -> Pipeline:
    """Assemble a one-stage pipeline that applies ``load_fn`` to each item of ``src``.

    Args:
        src: Iterable providing the items handed to ``load_fn``.
        load_fn: Callable applied to every item; it returns a list of arrays.
        num_workers: Size of the worker pool. The same value is used as the
            concurrency of the loading stage.
        mode: Kind of worker pool. ``"mt"`` selects a thread pool and
            ``"mp"`` selects a process pool.

    Returns:
        The built pipeline.

    Raises:
        ValueError: If ``mode`` is neither ``"mt"`` nor ``"mp"``.

    Note:
        The executor created here is passed to the loading stage; this
        function does not shut it down.
    """
    if mode == "mt":
        pool = ThreadPoolExecutor(num_workers)
    elif mode == "mp":
        pool = ProcessPoolExecutor(num_workers)
    else:
        raise ValueError(f'The `mode` must be either "mt" or "mp". Found: {mode}')

    builder = PipelineBuilder()
    builder = builder.add_source(src)
    builder = builder.pipe(load_fn, concurrency=num_workers, executor=pool)
    builder = builder.add_sink(buffer_size=1)
    return builder.build(num_threads=1)
76
77
def load_npy(items: list[bytes]) -> list[NDArray]:
    """Deserialize each NPY byte string with :py:func:`numpy.load`.

    Args:
        items: Serialized NPY payloads, one array per element.

    Returns:
        The decoded arrays, in the same order as ``items``.
    """
    arrays = []
    for blob in items:
        # Pickled payloads are refused; only plain array data is accepted.
        arrays.append(np.load(BytesIO(blob), allow_pickle=False))
    return arrays
81
82
def load_npy_spdl(items: list[bytes]) -> list[NDArray]:
    """Deserialize each NPY byte string with :py:func:`spdl.io.load_npy`.

    Args:
        items: Serialized NPY payloads, one array per element.

    Returns:
        The decoded arrays, in the same order as ``items``.
    """
    arrays = []
    for blob in items:
        arrays.append(spdl.io.load_npy(blob))
    return arrays
86
87
def load_npz(item: bytes) -> list[NDArray]:
    """Load arrays from a serialized NPZ binary string using :py:func:`numpy.load`.

    Args:
        item: The bytes of an NPZ archive.

    Returns:
        The arrays stored in the archive, in the order of its entries.
    """
    # ``allow_pickle=False`` is spelled out to match ``load_npy``. The archive
    # object is closed once all of its arrays have been read.
    with np.load(BytesIO(item), allow_pickle=False) as archive:
        return [archive[name] for name in archive.files]
92
93
def load_npz_spdl(item: bytes) -> list[NDArray]:
    """Load arrays from a serialized NPZ binary string using :py:func:`spdl.io.load_npz`.

    Args:
        item: The bytes of an NPZ archive.

    Returns:
        The values of the loaded archive, collected into a list.
    """
    return list(spdl.io.load_npz(item).values())
98
99
def load_torch(item: bytes) -> list[NDArray]:
    """Load the values of a serialized PyTorch state dict.

    Args:
        item: Bytes written by :py:func:`torch.save` for a dict.

    Returns:
        The values of the loaded dict as a list. They are returned as stored
        (``_dump_torch`` in this file stores :py:class:`torch.Tensor` objects);
        no conversion to NumPy arrays happens here.
    """
    # NOTE(security): ``torch.load`` is pickle-based. ``weights_only=True``
    # restricts unpickling to tensors and plain containers, and makes the
    # behavior independent of the installed PyTorch version's default.
    state = torch.load(BytesIO(item), weights_only=True)
    return list(state.values())
103
104
105def _get_load_fn(
106 data_format: str, impl: str
107) -> Callable[[list[bytes]], list[NDArray]] | Callable[[bytes], list[NDArray]]:
108 match data_format:
109 case "torch":
110 return load_torch
111 case "npy":
112 if impl == "spdl":
113 return load_npy_spdl
114 return load_npy
115 case "npz":
116 if impl == "spdl":
117 return load_npz_spdl
118 return load_npz
119 case _:
120 raise ValueError(f"Unexpected data format: {data_format}")
121
122
class DataSource(Generic[T]):
    """Iterable that yields one fixed object a set number of times.

    Args:
        data: The object yielded on every step.
        repeat: How many times ``data`` is yielded per iteration pass.
    """

    def __init__(self, data: T, repeat: int) -> None:
        self.repeat = repeat
        self.data = data

    def __iter__(self) -> Iterator[T]:
        # A new generator is produced on every call, so the source can be
        # iterated more than once.
        yield from (self.data for _ in range(self.repeat))
138
139
def run_pipeline(pipeline: Pipeline[...]) -> tuple[int, float]:
    """Consume the whole pipeline and time it.

    Args:
        pipeline: The pipeline to run. It is iterated inside its
            ``auto_stop()`` context.

    Returns:
        A pair of the number of items produced and the elapsed time in
        seconds, measured with :py:func:`time.monotonic`.
    """
    started = time.monotonic()
    with pipeline.auto_stop():
        # The items themselves are discarded; only the count is kept.
        num_items = sum(1 for _ in pipeline)
    # The clock is read after the ``auto_stop()`` context has exited.
    return num_items, time.monotonic() - started
149
150
151def _dump_np(arr: NDArray | dict[str, NDArray], compressed: bool = False) -> bytes:
152 with BytesIO() as buf:
153 if isinstance(arr, dict):
154 if compressed:
155 np.savez_compressed(buf, allow_pickle=False, **arr)
156 else:
157 np.savez(buf, allow_pickle=False, **arr)
158 else:
159 np.save(buf, arr, allow_pickle=False)
160 buf.seek(0)
161 return buf.read()
162
163
164def _dump_torch(arr: dict[str, NDArray]) -> bytes:
165 with BytesIO() as buf:
166 torch.save({k: torch.from_numpy(v) for k, v in arr.items()}, buf)
167 buf.seek(0)
168 return buf.read()
169
170
171def get_mock_data(format: str, compressed: bool = False) -> tuple[bytes, bytes] | bytes:
172 """Generate a single sample in the given format.
173
174 The mock data resemboles an RGB image and its segmentation labels.
175
176 Args:
177 format: One of ``"npz"``, ``"npy"`` or ``"torch"``.
178 compressed: If ``True``, NPZ file is compressed.
179 (i.e. :py:func:`numpy.savez_compressed` is used.)
180
181 Returns:
182 Serialized mock arrays. If ``"npy"`` then arrays are serialized
183 separately. Otherwise arrays are bundled together.
184 """
185 img = np.random.randint(256, size=(3, 640, 480), dtype=np.uint8)
186 lbl = np.random.randint(256, size=(640, 480), dtype=np.uint8)
187
188 match format:
189 case "npz":
190 return _dump_np({"img": img, "lbl": lbl}, compressed=compressed)
191 case "npy":
192 return _dump_np(img), _dump_np(lbl)
193 case "torch":
194 return _dump_torch({"img": img, "lbl": lbl})
195 case _:
196 raise ValueError(f"Unexpected `format`: {format}")
197
198
def main() -> None:
    """The entrypoint from CLI.

    For every (format, compressed, implementation) configuration, one mock
    sample is generated and repeated 1000 times through a pipeline, for each
    parallelism mode and worker count. One CSV line is printed per run:
    ``format,compressed,impl,mode,num_workers,qps``.
    """
    modes = ["mp", "mt"]
    worker_counts = [1, 2, 4, 8, 16, 32]
    # (data format, compressed, loader implementation)
    configs = [
        ("torch", False, "torch"),
        ("npy", False, "np"),
        ("npy", False, "spdl"),
        ("npz", False, "np"),
        ("npz", True, "np"),
        ("npz", False, "spdl"),
        ("npz", True, "spdl"),
    ]
    for data_format, compressed, impl in configs:
        sample = get_mock_data(data_format, compressed)
        src = DataSource(sample, repeat=1000)
        load_fn = _get_load_fn(data_format, impl)
        for mode in modes:
            for num_workers in worker_counts:
                pipeline = get_pipeline(
                    src,  # pyre-ignore: [6]
                    load_fn,
                    num_workers,
                    mode,
                )
                num_items, elapsed = run_pipeline(pipeline)
                qps = num_items / elapsed
                print(f"{data_format},{compressed},{impl},{mode},{num_workers},{qps}")


if __name__ == "__main__":
    main()
Functions¶
Functions
- get_mock_data(format: str, compressed: bool = False) tuple[bytes, bytes] | bytes [source]¶
Generate a single sample in the given format.
The mock data resembles an RGB image and its segmentation labels.
- Parameters:
format – One of "npz", "npy" or "torch".
compressed – If True, the NPZ file is compressed (i.e. numpy.savez_compressed() is used).
- Returns:
Serialized mock arrays. If "npy", the arrays are serialized separately. Otherwise the arrays are bundled together.
- get_pipeline(src: Iterable[T], load_fn: Callable[[T], list[ndarray[tuple[int, ...], dtype[_ScalarType_co]]]], num_workers: int, mode: str) Pipeline [source]¶
Build a pipeline to iterate the source with the load function in different parallelism.
- Parameters:
src – The data source.
load_fn – The function that loads NumPy NDArray from the source byte string.
num_workers – The number of worker threads or processes.
mode – The mode of parallelism. The valid values are "mt" (multi-threading) and "mp" (multi-processing).
- Returns:
The resulting pipeline.
- load_npy(items: list[bytes]) list[ndarray[tuple[int, ...], dtype[_ScalarType_co]]] [source]¶
Load arrays from serialized NPY binary strings using numpy.load().
- load_npy_spdl(items: list[bytes]) list[ndarray[tuple[int, ...], dtype[_ScalarType_co]]] [source]¶
Load arrays from serialized NPY binary strings using spdl.io.load_npy().
Classes¶
Classes