Data formats
This example benchmarks the speed of loading data in different formats.
See Case Studies / Data Format for details on how the data format and the loading function affect the performance of the training pipeline.
Source
Click here to see the source.
1#!/usr/bin/env python3
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8"""This example benchmarks the speed of loading data in different formats.
9
10See `Case Studies / Data Format <../case_studies/data_format.html>`_ for
11the detail of how data format and the loading function affects
12the performance of the training pipeline.
13"""
14
# Public API of this example. Every loader function is listed so that each one
# is exported (and picked up by the API documentation) in the same way.
__all__ = [
    "main",
    "get_mock_data",
    "get_pipeline",
    "load_npy",
    "load_npy_spdl",
    "load_npz",
    "load_torch",
    "run_pipeline",
    "DataSource",
]
25
26import time
27from collections.abc import Callable, Iterable, Iterator
28from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
29from io import BytesIO
30from typing import Generic, TypeVar
31
32import numpy as np
33import spdl.io
34import torch
35from numpy.typing import NDArray
36from spdl.pipeline import Pipeline, PipelineBuilder
37
38# pyre-strict
39
# Type of the items yielded by a data source and handed to the load function
# (used by ``get_pipeline`` and ``DataSource``).
T = TypeVar("T")
41
42
def get_pipeline(
    src: Iterable[T],
    load_fn: Callable[[T], list[NDArray]],
    num_workers: int,
    mode: str,
) -> Pipeline:
    """Construct a pipeline that applies ``load_fn`` to each item of ``src`` concurrently.

    Args:
        src: The data source.
        load_fn: The function that loads NumPy NDArray from the source byte string.
        num_workers: The number of worker threads or processes. The same value
            is used as the concurrency of the loading stage.
        mode: The mode of parallelism. The valid values are ``"mt"`` (multi-threading)
            and ``"mp"`` (multi-processing).

    Returns:
        The resulting pipeline.

    Raises:
        ValueError: If ``mode`` is neither ``"mt"`` nor ``"mp"``.
    """
    # The executor class decides whether ``load_fn`` runs in threads or processes.
    if mode == "mt":
        executor_cls = ThreadPoolExecutor
    elif mode == "mp":
        executor_cls = ProcessPoolExecutor
    else:
        raise ValueError(f'The `mode` must be either "mt" or "mp". Found: {mode}')
    executor = executor_cls(num_workers)

    builder = PipelineBuilder()
    builder = builder.add_source(src)
    builder = builder.pipe(load_fn, concurrency=num_workers, executor=executor)
    builder = builder.add_sink(buffer_size=1)
    return builder.build(num_threads=1)
76
77
def load_npy(items: list[bytes]) -> list[NDArray]:
    """Deserialize each NPY-encoded byte string in ``items`` with :py:func:`numpy.load`.

    Pickling is disabled (``allow_pickle=False``), so object-array payloads
    are rejected by NumPy.

    Args:
        items: Serialized NPY binary strings.

    Returns:
        The decoded arrays, in the same order as ``items``.
    """

    def _decode(blob: bytes) -> NDArray:
        return np.load(BytesIO(blob), allow_pickle=False)

    return list(map(_decode, items))
81
82
def load_npy_spdl(items: list[bytes]) -> list[NDArray]:
    """Deserialize each NPY-encoded byte string in ``items`` with :py:func:`spdl.io.load_npy`.

    Args:
        items: Serialized NPY binary strings.

    Returns:
        The decoded arrays, in the same order as ``items``.
    """
    decoded: list[NDArray] = []
    for blob in items:
        decoded.append(spdl.io.load_npy(blob))
    return decoded
86
87
def load_npz(item: bytes) -> list[NDArray]:
    """Load arrays from a serialized NPZ binary string using :py:func:`numpy.load`.

    Args:
        item: A serialized NPZ archive.

    Returns:
        The arrays stored in the archive, in the order they appear in it.
    """
    # ``numpy.load`` returns an ``NpzFile`` that wraps a ``zipfile.ZipFile``.
    # Close it deterministically instead of leaving it to garbage collection.
    # ``list(...)`` reads every member before the archive is closed.
    # ``allow_pickle=False`` is NumPy's default; it is spelled out to match ``load_npy``.
    with np.load(BytesIO(item), allow_pickle=False) as archive:
        return list(archive.values())
92
93
def load_torch(item: bytes) -> list[NDArray]:
    """Load arrays from a serialized PyTorch state dict.

    Args:
        item: Bytes written by :py:func:`torch.save` for a mapping
            (presumably name -> tensor, as produced for the mock data).

    Returns:
        The values of the deserialized mapping, in its iteration order.
        Note that these are :py:class:`torch.Tensor` objects rather than
        NumPy arrays, despite the return annotation shared with the other loaders.
    """
    # ``weights_only=True`` restricts unpickling to tensors and plain containers,
    # so loading can never execute arbitrary pickled code. It only became the
    # default in PyTorch 2.6; passing it explicitly gives the same behavior on
    # older versions and silences their ``FutureWarning``.
    return list(torch.load(BytesIO(item), weights_only=True).values())
97
98
99def _get_load_fn(
100 data_format: str, provider: str
101) -> Callable[[list[bytes]], list[NDArray]] | Callable[[bytes], list[NDArray]]:
102 match data_format:
103 case "torch":
104 return load_torch
105 case "npy":
106 if provider == "spdl":
107 return load_npy_spdl
108 return load_npy
109 case "npz":
110 return load_npz
111 case _:
112 raise ValueError(f"Unexpected data format: {data_format}")
113
114
class DataSource(Generic[T]):
    """Iterable that hands out one fixed object a fixed number of times.

    Each call to ``iter()`` starts a fresh pass, so one instance can be
    iterated repeatedly.

    Args:
        data: The object yielded on every step (the same object, not a copy).
        repeat: How many times ``data`` is yielded per pass.
    """

    def __init__(self, data: T, repeat: int) -> None:
        self.data, self.repeat = data, repeat

    def __iter__(self) -> Iterator[T]:
        yield from (self.data for _ in range(self.repeat))
130
131
def run_pipeline(pipeline: Pipeline[...]) -> tuple[int, float]:
    """Drain ``pipeline`` and report how many items it produced and how long that took.

    Args:
        pipeline: The pipeline to consume. It is iterated inside
            ``pipeline.auto_stop()``.

    Returns:
        ``(num_items, elapsed)``, where ``elapsed`` is in seconds as measured by
        :py:func:`time.monotonic`. It includes the time spent entering and
        exiting ``pipeline.auto_stop()``.
    """
    start = time.monotonic()
    with pipeline.auto_stop():
        count = sum(1 for _ in pipeline)
    return count, time.monotonic() - start
141
142
143def _dump_np(arr: NDArray | dict[str, NDArray]) -> bytes:
144 with BytesIO() as buf:
145 if isinstance(arr, dict):
146 np.savez(buf, **arr, allow_pickle=False)
147 else:
148 np.save(buf, arr, allow_pickle=False)
149 buf.seek(0)
150 return buf.read()
151
152
def _dump_torch(arr: dict[str, NDArray]) -> bytes:
    """Serialize a name-to-array mapping with :py:func:`torch.save`.

    Each array is first wrapped with :py:func:`torch.from_numpy`, and the
    resulting dict of tensors is saved.

    Args:
        arr: Mapping from name to NumPy array.

    Returns:
        The serialized bytes.
    """
    tensors = {}
    for name, value in arr.items():
        tensors[name] = torch.from_numpy(value)
    sink = BytesIO()
    torch.save(tensors, sink)
    return sink.getvalue()
158
159
160def get_mock_data(format: str) -> tuple[bytes, bytes] | bytes:
161 """Generate a single sample in the given format.
162
163 The mock data resemboles an RGB image and its segmentation labels.
164
165 Args:
166 format: One of ``"npz"``, ``"npy"`` or ``"torch"``.
167
168 Returns:
169 Serialized mock arrays. If ``"npy"`` then arrays are serialized
170 separately. Otherwise arrays are bundled together.
171 """
172 img = np.random.randint(256, size=(3, 640, 480), dtype=np.uint8)
173 lbl = np.random.randint(256, size=(640, 480), dtype=np.uint8)
174
175 match format:
176 case "npz":
177 return _dump_np({"img": img, "lbl": lbl})
178 case "npy":
179 return _dump_np(img), _dump_np(lbl)
180 case "torch":
181 return _dump_torch({"img": img, "lbl": lbl})
182 case _:
183 raise ValueError(f"Unexpected `format`: {format}")
184
185
def main() -> None:
    """Benchmark every (data format, loader) pair and print one CSV row per run.

    For each pair, a single mock sample is served 1000 times by a
    ``DataSource``. It is loaded through pipelines using multi-processing
    and then multi-threading, with 1, 2, 4, 8, 16 and 32 workers. Each run
    prints ``data_format,io_func,mode,num_workers,qps``.
    """
    modes = ("mp", "mt")
    worker_counts = (1, 2, 4, 8, 16, 32)
    configs = (
        ("torch", "torch"),
        ("npy", "npy"),
        ("npy", "spdl"),
        ("npz", "npz"),
    )
    for data_format, io_func in configs:
        src = DataSource(get_mock_data(data_format), repeat=1000)
        load_fn = _get_load_fn(data_format, io_func)
        # Mode is the outer dimension and worker count the inner one.
        runs = [(mode, count) for mode in modes for count in worker_counts]
        for mode, num_workers in runs:
            pipeline = get_pipeline(
                src,  # pyre-ignore: [6]
                load_fn,
                num_workers,
                mode,
            )
            num_items, elapsed = run_pipeline(pipeline)
            qps = num_items / elapsed
            print(f"{data_format},{io_func},{mode},{num_workers},{qps}")
208
209
if __name__ == "__main__":
    # Run the benchmark only when this file is executed as a script.
    main()
Functions
- get_mock_data(format: str) -> tuple[bytes, bytes] | bytes [source]
Generate a single sample in the given format.
The mock data resembles an RGB image and its segmentation labels.
- Parameters:
format – One of "npz", "npy" or "torch".
- Returns:
Serialized mock arrays. If "npy", the arrays are serialized separately. Otherwise, the arrays are bundled together.
- get_pipeline(src: Iterable[T], load_fn: Callable[[T], list[ndarray[tuple[int, ...], dtype[_ScalarType_co]]]], num_workers: int, mode: str) -> Pipeline [source]
Build a pipeline that iterates over the source and applies the load function with the given kind of parallelism.
- Parameters:
src – The data source.
load_fn – The function that loads NumPy NDArray from the source byte string.
num_workers – The number of worker threads or processes.
mode – The mode of parallelism. The valid values are "mt" (multi-threading) and "mp" (multi-processing).
- Returns:
The resulting pipeline.
- load_npy(items: list[bytes]) -> list[ndarray[tuple[int, ...], dtype[_ScalarType_co]]] [source]
Load arrays from serialized NPY binary strings using numpy.load().
- load_npy_spdl(items: list[bytes]) -> list[ndarray[tuple[int, ...], dtype[_ScalarType_co]]] [source]
Load arrays from serialized NPY binary strings using spdl.io.load_npy().
Classes