multi_thread_preprocessing
This example shows how to run PyTorch transforms in SPDL Pipeline and compares its performance against PyTorch DataLoader.
Each pipeline reads images from the ImageNet dataset, resizes and batches them, applies pixel normalization, and then transfers the data to the GPU.
In the PyTorch and TorchVision native solution, the images are decoded and resized using Pillow, batched with torch.utils.data.default_collate(), normalized with torchvision.transforms.Normalize, and transferred to the GPU with torch.Tensor.cuda().
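Both pipelines normalize a batch by first scaling uint8 pixel values to [0, 1] with `/ 255` and then applying `Normalize`, which subtracts a per-channel mean and divides by a per-channel standard deviation. The pure-Python sketch below spells out that arithmetic for a single RGB pixel using the mean and std values from the source; `normalize_pixel` is a hypothetical helper, not part of the script.

```python
# ImageNet statistics used in the source (one value per RGB channel).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: tuple[int, int, int]) -> tuple[float, ...]:
    """Hypothetical helper: scale a uint8 RGB pixel to [0, 1], then standardize each channel.

    Mirrors ``tensor.float() / 255`` followed by ``Normalize(mean=MEAN, std=STD)``.
    """
    return tuple((value / 255 - m) / s for value, m, s in zip(rgb, MEAN, STD))
```

For example, a black pixel `(0, 0, 0)` maps to `-mean / std` on each channel, which is about -2.12 for the red channel.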
When torch.utils.data.DataLoader is used, batches are created and normalized in worker subprocesses, then transferred to the main process before being sent to the GPU.
The following diagram illustrates this.
In contrast, SPDL Pipeline executes the transforms in the main process, using multiple threads. SPDL Pipeline uses its own implementations for decoding, resizing, and batching image data.
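On the SPDL side, `.pipe(dataset.__getitem__, concurrency=num_workers)` followed by `.aggregate(batch_size or 1)` runs the per-image work on threads of the main process and then groups the results into lists. The sketch below imitates that shape with the standard library's `ThreadPoolExecutor`. It is an illustration of the threading model under simplifying assumptions, not SPDL's implementation: it preserves input order and does not model SPDL's ordering, buffering, or backpressure behavior. `threaded_batches` is a hypothetical name.

```python
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from typing import TypeVar

T = TypeVar("T")
U = TypeVar("U")


def threaded_batches(
    source: Iterable[T],
    fn: Callable[[T], U],
    num_threads: int,
    batch_size: int,
) -> Iterator[list[U]]:
    """Hypothetical helper: apply ``fn`` on a thread pool, yield results in lists of ``batch_size``.

    The last list may be shorter when the source length is not a multiple of ``batch_size``.
    """
    batch: list[U] = []
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # executor.map submits every item up front (no backpressure) and
        # yields results in input order, even though fn runs concurrently.
        for result in executor.map(fn, source):
            batch.append(result)
            if len(batch) == batch_size:
                yield batch
                batch = []
    if batch:
        yield batch
```

Because the threads share memory with the caller, no batch has to be shipped from a worker process to the main process. Threads only speed things up when the per-item work releases the GIL, as native decoding code can.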
This script runs the pipeline with the different configurations described below while varying the number of workers.
1. Image decoding and resizing
2. Image decoding, resizing, and batching
3. Image decoding, resizing, batching, and normalization
4. Image decoding, resizing, batching, normalization, and transfer to GPU
The following result was obtained.
The following observations can be made.
- In both implementations, the throughput peaks at around 16 workers and then decreases as the number of workers increases further.
- The throughput increases when images are batched, then decreases as additional processing is added.
- The improvement from batching is significantly larger in SPDL than in PyTorch (more than 2x at 16 workers).
- The peak throughput of SPDL is almost 2.7x that of PyTorch.
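The peak-throughput comparison in the last observation can be computed from the dictionary returned by `run_test` in the source, which maps each function name (`exp_torch_`, `exp_spdl_`) to a `(num_workers_list, fps_list)` pair. `peak_speedup` is a hypothetical helper, not part of the script.

```python
def peak_speedup(
    data: dict[str, tuple[list[int], list[float]]],
    baseline: str = "exp_torch_",
    candidate: str = "exp_spdl_",
) -> tuple[int, int, float]:
    """Hypothetical helper: return (baseline peak workers, candidate peak workers, peak FPS ratio).

    ``data`` has the shape returned by ``run_test``: name -> (num_workers, fps).
    """

    def peak(name: str) -> tuple[int, float]:
        workers, fps = data[name]
        # Index of the highest throughput in this implementation's sweep.
        best = max(range(len(fps)), key=fps.__getitem__)
        return workers[best], fps[best]

    base_workers, base_fps = peak(baseline)
    cand_workers, cand_fps = peak(candidate)
    return base_workers, cand_workers, cand_fps / base_fps
```

Each implementation is compared at its own best worker count, so the two peaks need not occur at the same number of workers.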
Source
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""This example shows how to run PyTorch transforms in SPDL Pipeline
and compares its performance against PyTorch DataLoader.

Each pipeline reads images from the ImageNet dataset, resizes and batches
them, applies pixel normalization, and then transfers the data to the GPU.

In the PyTorch and TorchVision native solution, the images are decoded
and resized using Pillow, batched with :py:func:`torch.utils.data.default_collate`,
normalized with :py:class:`torchvision.transforms.Normalize`,
and transferred to the GPU with :py:func:`torch.Tensor.cuda`.

When :py:class:`torch.utils.data.DataLoader` is used, batches are created and
normalized in worker subprocesses, then transferred to the main process before
being sent to the GPU.

The following diagram illustrates this.

.. include:: ../plots/multi_thread_preprocessing_chart_torch.txt

In contrast, SPDL Pipeline executes the transforms in the main process, using
multiple threads. SPDL Pipeline uses its own implementations for decoding,
resizing, and batching image data.

.. include:: ../plots/multi_thread_preprocessing_chart_spdl.txt

This script runs the pipeline with the different configurations described below
while varying the number of workers.

1. Image decoding and resizing
2. Image decoding, resizing, and batching
3. Image decoding, resizing, batching, and normalization
4. Image decoding, resizing, batching, normalization, and transfer to GPU

The following result was obtained.

.. include:: ../plots/multi_thread_preprocessing_plot.txt

The following observations can be made.

- In both implementations, the throughput peaks at around 16 workers,
  and then decreases as the number of workers increases further.
- The throughput increases when images are batched, then decreases
  as additional processing is added.
- The improvement from batching is significantly larger in SPDL
  than in PyTorch (more than 2x at 16 workers).
- The peak throughput of SPDL is almost 2.7x that of PyTorch.
"""

import logging
import multiprocessing
import time
from collections.abc import Iterable
from multiprocessing import Process, Queue

import spdl.io
import torch
from spdl.pipeline import PipelineBuilder
from torchvision.datasets import ImageNet
from torchvision.transforms import Compose, Normalize, PILToTensor, Resize

__all__ = [
    "entrypoint",
    "exp_torch",
    "exp_spdl",
    "run_dataloader",
]


logging.getLogger().setLevel(logging.ERROR)


def run_dataloader(
    dataloader: Iterable,
    max_items: int,
) -> tuple[int, float]:
    """Run the given dataloader and measure its performance.

    Args:
        dataloader: The dataloader to benchmark.
        max_items: The maximum number of items to process.

    Returns:
        The number of items processed and the elapsed time in seconds.
    """
    num_items = 0
    t0 = time.monotonic()
    try:
        for i, (data, _) in enumerate(dataloader, start=1):
            # A single (C, H, W) image counts as one item; a batch counts as len(batch).
            num_items += 1 if data.ndim == 3 else len(data)
            # ``max_items`` bounds the number of iterations (images or batches).
            if i >= max_items:
                break
    finally:
        elapsed = time.monotonic() - t0
    return num_items, elapsed


def exp_torch(
    *,
    root_dir: str,
    split: str,
    num_workers: int,
    max_items: int,
    batch_size: int | None = None,
    normalize: bool = False,
    transfer: bool = False,
) -> tuple[int, float]:
    """Load data with PyTorch native operation using PyTorch DataLoader.

    This is the baseline for comparison.

    Args:
        root_dir: The root directory of the ImageNet dataset.
        split: The dataset split, such as "train" and "val".
        num_workers: The number of workers to use.
        max_items: The maximum number of items to process.
        batch_size: The batch size. If ``None``, images are processed one at a time.
        normalize: Whether to normalize the data. Only applicable when ``batch_size`` is given.
        transfer: Whether to transfer the data to GPU.

    Returns:
        The number of items processed and the elapsed time in seconds.
    """
    dataset = ImageNet(
        root=root_dir,
        split=split,
        transform=Compose([Resize((224, 224)), PILToTensor()]),
    )

    normalize_transform = Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Runs in the DataLoader worker subprocesses when num_workers > 0.
    def collate(item):
        batch, cls = torch.utils.data.default_collate(item)
        if normalize:
            batch = batch.float() / 255
            batch = normalize_transform(batch)
        return batch, cls

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=None if batch_size is None else collate,
        prefetch_factor=1,
        multiprocessing_context="fork",
    )

    if transfer:

        def with_transfer(dataloader):
            for tensor, cls in dataloader:
                tensor = tensor.cuda()
                yield tensor, cls

        dataloader = with_transfer(dataloader)

    with torch.no_grad():
        return run_dataloader(dataloader, max_items)


def exp_spdl(
    *,
    root_dir: str,
    split: str,
    num_workers: int,
    max_items: int,
    batch_size: int | None = None,
    normalize: bool = False,
    transfer: bool = False,
) -> tuple[int, float]:
    """Load data with SPDL operation using SPDL Pipeline.

    Args:
        root_dir: The root directory of the ImageNet dataset.
        split: The dataset split, such as "train" and "val".
        num_workers: The number of workers to use.
        max_items: The maximum number of items to process.
        batch_size: The batch size. If ``None``, images are processed one at a time.
        normalize: Whether to normalize the data. Only applicable when ``batch_size`` is given.
        transfer: Whether to transfer the data to GPU.

    Returns:
        The number of items processed and the elapsed time in seconds.
    """
    filter_desc = spdl.io.get_video_filter_desc(
        scale_width=224,
        scale_height=224,
    )

    def decode_image(path):
        packets = spdl.io.demux_image(path)
        return spdl.io.decode_packets(packets, filter_desc=filter_desc)

    dataset = ImageNet(
        root=root_dir,
        split=split,
        loader=decode_image,
    )

    def convert(items):
        frames, cls = list(zip(*items))
        buffer = spdl.io.convert_frames(frames)
        # NHWC -> NCHW
        tensor = spdl.io.to_torch(buffer).permute(0, 3, 1, 2)
        return tensor, cls

    builder = (
        PipelineBuilder()
        .add_source(range(len(dataset)))
        # Decode and resize images concurrently on the pipeline's thread pool.
        .pipe(dataset.__getitem__, concurrency=num_workers)
        # Group decoded frames into lists of ``batch_size`` (1 when unbatched).
        .aggregate(batch_size or 1)
        # Convert each list of frames into a single NCHW tensor.
        .pipe(convert)
    )

    if normalize:
        transform = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Named ``_normalize`` to avoid shadowing the ``normalize`` argument.
        def _normalize(item):
            tensor, cls = item
            tensor = tensor.float() / 255
            tensor = transform(tensor)
            return tensor, cls

        builder = builder.pipe(_normalize)

    if transfer:
        builder = builder.pipe(lambda item: (item[0].cuda(), item[1]))

    builder = builder.add_sink(num_workers)
    pipeline = builder.build(num_threads=num_workers)

    with torch.no_grad(), pipeline.auto_stop():
        return run_dataloader(pipeline, max_items)


##############################################################################
# Execute each test function in a subprocess to isolate the runs
##############################################################################
def exp_torch_(queue, **kwargs):
    queue.put(exp_torch(**kwargs))


def exp_spdl_(queue, **kwargs):
    queue.put(exp_spdl(**kwargs))


def run_in_process(func, **kwargs):
    queue = Queue()
    # ``start()`` launches a new process; ``run()`` would execute ``func``
    # in the current process and defeat the isolation.
    process = Process(target=func, args=[queue], kwargs=kwargs)
    process.start()
    # Read the result before joining so the child is not blocked flushing the queue.
    result = queue.get()
    process.join()
    return result


def run_test(**kwargs):
    data = {}
    num_workers_ = [1, 2, 4, 8, 16, 32]
    for func in [exp_torch_, exp_spdl_]:
        print(func.__name__)
        print("\tnum_workers\tFPS")
        y = []
        for num_workers in num_workers_:
            num_images, elapsed = run_in_process(
                func, num_workers=num_workers, **kwargs
            )
            qps = num_images / elapsed
            y.append(qps)
            print(f"\t{num_workers}\t{qps:8.2f} ({num_images} / {elapsed:5.2f})")

        data[func.__name__] = (num_workers_, y)

    return data


def _print(data):
    for i, (x, y) in enumerate(data.values()):
        if i == 0:
            print("\t".join(str(v) for v in x))
        print("\t".join(f"{v:.2f}" for v in y))


def entrypoint(
    root_dir: str,
    split: str,
    batch_size: int,
    max_items: int,
):
    """The main entrypoint for CLI.

    Args:
        root_dir: The root directory of the ImageNet dataset.
        split: Dataset split, such as "train" and "val".
        batch_size: The batch size to use.
        max_items: The maximum number of items to process.
    """
    multiprocessing.set_start_method("spawn")

    argset = (
        {"batch_size": None},
        {"batch_size": batch_size},
        {"batch_size": batch_size, "normalize": True},
        {"batch_size": batch_size, "normalize": True, "transfer": True},
    )

    for kwargs in argset:
        print(kwargs)
        data = run_test(root_dir=root_dir, split=split, max_items=max_items, **kwargs)
        _print(data)


def _parse_args():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root-dir",
        help="Directory where the ImageNet dataset is stored.",
        default="/home/moto/local/imagenet/",
    )
    parser.add_argument("--batch-size", default=32, type=int)
    parser.add_argument(
        "--max-items",
        type=int,
        help="The maximum number of items (images or batches) to process.",
        default=100,
    )
    parser.add_argument(
        "--split",
        default="val",
    )
    return parser.parse_args()


if __name__ == "__main__":
    _args = _parse_args()
    entrypoint(
        _args.root_dir,
        _args.split,
        _args.batch_size,
        _args.max_items,
    )
Functions
- entrypoint(root_dir: str, split: str, batch_size: int, max_items: int)
The main entrypoint for CLI.
- Parameters:
root_dir – The root directory of the ImageNet dataset.
split – Dataset split, such as “train” and “val”.
batch_size – The batch size to use.
max_items – The maximum number of items to process.
- exp_torch(*, root_dir: str, split: str, num_workers: int, max_items: int, batch_size: int | None = None, normalize: bool = False, transfer: bool = False) → tuple[int, float]
Load data with PyTorch native operation using PyTorch DataLoader.
This is the baseline for comparison.
- Parameters:
root_dir – The root directory of the ImageNet dataset.
split – The dataset split, such as “train” and “val”.
num_workers – The number of workers to use.
max_items – The maximum number of items to process.
batch_size – The batch size. If None, images are processed one at a time.
normalize – Whether to normalize the data. Only applicable when batch_size is given.
transfer – Whether to transfer the data to GPU.
- Returns:
The number of items processed and the elapsed time in seconds.
- exp_spdl(*, root_dir: str, split: str, num_workers: int, max_items: int, batch_size: int | None = None, normalize: bool = False, transfer: bool = False) → tuple[int, float]
Load data with SPDL operation using SPDL Pipeline.
- Parameters:
root_dir – The root directory of the ImageNet dataset.
split – The dataset split, such as “train” and “val”.
num_workers – The number of workers to use.
max_items – The maximum number of items to process.
batch_size – The batch size. If None, images are processed one at a time.
normalize – Whether to normalize the data. Only applicable when batch_size is given.
transfer – Whether to transfer the data to GPU.
- Returns:
The number of items processed and the elapsed time in seconds.
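Inside exp_spdl, the convert step receives a list of (frame, class) pairs from aggregate, splits it with zip(*items), and reorders the decoded NHWC buffer into the NCHW layout with permute(0, 3, 1, 2). The pure-Python sketch below shows both index manipulations on nested lists; `split_pairs` and `nhwc_to_nchw` are hypothetical helpers, and the real code operates on tensors rather than lists.

```python
from typing import Any


def split_pairs(items: list[tuple[Any, int]]) -> tuple[list[Any], list[int]]:
    """Hypothetical helper: split [(frame, cls), ...] into (frames, classes).

    Mirrors ``frames, cls = list(zip(*items))`` in ``convert``, but returns lists.
    """
    frames, classes = zip(*items)
    return list(frames), list(classes)


def nhwc_to_nchw(batch: list) -> list:
    """Hypothetical helper: reorder a nested-list batch from [N][H][W][C] to [N][C][H][W].

    The resulting layout matches ``tensor.permute(0, 3, 1, 2)`` on an NHWC tensor.
    """
    return [
        [
            [
                [image[h][w][c] for w in range(len(image[0]))]
                for h in range(len(image))
            ]
            for c in range(len(image[0][0]))
        ]
        for image in batch
    ]
```

In both layouts the value at a given pixel and channel is the same; only the order of the axes changes, so `out[n][c][h][w] == batch[n][h][w][c]`.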
- run_dataloader(dataloader: Iterable, max_items: int) → tuple[int, float]
Run the given dataloader and measure its performance.
- Parameters:
dataloader – The dataloader to benchmark.
max_items – The maximum number of items to process.
- Returns:
The number of items processed and the elapsed time in seconds.
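run_dataloader counts one item per iteration when data is a single CHW image (data.ndim == 3) and len(data) items otherwise, while max_items caps the number of iterations. With batching, max_items therefore bounds the number of batches rather than images. The sketch below copies the function body from the source and feeds it a hypothetical `FakeBatch` stand-in that only provides `ndim` and `len()`, so the counting can be checked without PyTorch or ImageNet; `FakeBatch` and `fake_batches` are illustrative names, not part of the script.

```python
import time
from collections.abc import Iterable, Iterator
from dataclasses import dataclass


@dataclass
class FakeBatch:
    """Hypothetical stand-in for a tensor: only ``ndim`` and ``len()`` are used."""

    ndim: int
    size: int

    def __len__(self) -> int:
        return self.size


def run_dataloader(dataloader: Iterable, max_items: int) -> tuple[int, float]:
    # Same logic as ``run_dataloader`` in the source listing.
    num_items = 0
    t0 = time.monotonic()
    try:
        for i, (data, _) in enumerate(dataloader, start=1):
            num_items += 1 if data.ndim == 3 else len(data)
            if i >= max_items:
                break
    finally:
        elapsed = time.monotonic() - t0
    return num_items, elapsed


def fake_batches(num_batches: int, batch_size: int) -> Iterator[tuple[FakeBatch, list[int]]]:
    """Yield ``num_batches`` NCHW-like batches with dummy labels."""
    for _ in range(num_batches):
        yield FakeBatch(ndim=4, size=batch_size), [0] * batch_size
```

With the CLI defaults (--batch-size 32, --max-items 100), a batched run processes up to 3,200 images, while the unbatched configuration processes up to 100.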