ImageNet classification
Benchmark the performance of loading images from local file systems and classifying them using a GPU.
This script builds the data loader and instantiates an image classification model on a GPU. The data loader transfers batches of image data to the GPU concurrently, while the foreground thread runs the model on the batches one by one.
To run the benchmark, pass the dataset root directory and split to the script as follows.
python imagenet_classification.py \
    --root-dir ~/imagenet/ \
    --split val
Source
Source
Click here to see the source.
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7"""Benchmark the performance of loading images from local file systems and
8classifying them using a GPU.
9
10This script builds the data loader and instantiates an image
11classification model in a GPU.
12The data loader transfers batches of image data to the GPU concurrently, while
13the foreground thread runs the model on the batches one by one.
14
15.. include:: ../plots/imagenet_classification_chart.txt
16
17To run the benchmark, pass the dataset root directory and split to the script as follows.
18
19.. code-block::
20
21 python imagenet_classification.py
22 --root-dir ~/imagenet/
23 --split val
24"""
25
26# pyre-strict
27
28import argparse
29import contextlib
30import logging
31import time
32from argparse import Namespace
33from collections.abc import Callable, Iterator
34from pathlib import Path
35
36import spdl.io
37import spdl.io.utils
38import torch
39from spdl.dataloader import DataLoader
40from spdl.source.imagenet import ImageNet
41from torch import Tensor
42from torch.profiler import profile
43
# Logger shared by every function in this module.
_LG: logging.Logger = logging.getLogger(name=__name__)


# Names exported by ``from <module> import *``.
__all__ = [
    "entrypoint",
    "benchmark",
    "get_decode_func",
    "get_dataloader",
    "get_model",
    "ModelBundle",
    "Classification",
    "Preprocessing",
]
57
58
59def _parse_args(args: list[str] | None) -> Namespace:
60 parser = argparse.ArgumentParser(
61 description=__doc__,
62 formatter_class=argparse.RawDescriptionHelpFormatter,
63 )
64 parser.add_argument("--debug", action="store_true")
65 parser.add_argument("--root-dir", type=Path, required=True)
66 parser.add_argument("--max-batches", type=int, default=float("inf"))
67 parser.add_argument("--batch-size", type=int, default=32)
68 parser.add_argument("--split", default="val", choices=["train", "val"])
69 parser.add_argument("--trace", type=Path)
70 parser.add_argument("--buffer-size", type=int, default=16)
71 parser.add_argument("--num-threads", type=int, default=16)
72 parser.add_argument("--no-compile", action="store_false", dest="compile")
73 parser.add_argument("--no-bf16", action="store_false", dest="use_bf16")
74 parser.add_argument("--use-nvjpeg", action="store_true")
75 ns = parser.parse_args(args)
76 if ns.trace:
77 ns.max_batches = 60
78 return ns
79
80
# Handroll the transforms so as to support `torch.compile`
class Preprocessing(torch.nn.Module):
    """Convert an image batch to float and normalize it per channel.

    Args:
        mean: The mean value of the dataset. Registered as a buffer, so it
            follows the module across ``.to(...)`` calls.
        std: The standard deviation of the dataset. Also registered as a buffer.
    """

    def __init__(self, mean: Tensor, std: Tensor) -> None:
        super().__init__()
        for buffer_name, value in (("mean", mean), ("std", std)):
            self.register_buffer(buffer_name, value)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize the given image batch.

        Args:
            x: The input image batch. Pixel values are expected to be
                in the range of ``[0, 255]``.

        Returns:
            ``(x / 255 - mean) / std``, computed in ``float32``.
        """
        unit_range = x.to(torch.float32).div(255.0)
        # pyrefly: ignore [unsupported-operation]
        return unit_range.sub(self.mean).div(self.std)
107
108
class Classification(torch.nn.Module):
    """Classification()"""

    def forward(self, x: Tensor, labels: Tensor) -> tuple[Tensor, Tensor]:
        """Given a batch of model outputs and labels, count top1 and top5 hits.

        Args:
            x: Model outputs (class scores) of shape ``(batch_size, num_classes)``,
                with ``num_classes >= 5``.
            labels: Ground-truth class indices of shape ``(batch_size,)`` or
                ``(batch_size, 1)``.

        Returns:
            A tuple of scalar tensors: the number of samples whose label is the
            top-1 prediction, and the number whose label is among the top-5
            predictions. Divide by ``batch_size`` to get accuracy.
        """
        # softmax is strictly monotonic, so ranking the raw scores selects the
        # same top-k classes. Skipping it saves a kernel and avoids ties that
        # rounding low-precision (e.g. bfloat16) probabilities can introduce.
        _, top_catid = torch.topk(x, 5, dim=-1)
        # Make labels a column so each row is compared only with its own label.
        # A flat (batch_size,) tensor would otherwise broadcast across rows.
        labels = labels.reshape(-1, 1)
        top1 = (top_catid[:, :1] == labels).sum()
        top5 = (top_catid == labels).sum()
        return top1, top5
128
129
class ModelBundle(torch.nn.Module):
    """ModelBundle()

    Chain the preprocessing transform, the model backbone, and the
    classification head into a single module, so the whole pipeline runs
    with one call.

    Args:
        model: The model backbone.
        preprocessing: Module that normalizes the raw image batch.
        classification: Module that turns the backbone output and the labels
            into top-1 / top-5 counts.
        use_bf16: Whether to cast the normalized batch to ``bfloat16`` before
            it is fed to ``model``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        preprocessing: Preprocessing,
        classification: Classification,
        use_bf16: bool,
    ) -> None:
        super().__init__()
        self.model = model
        self.preprocessing = preprocessing
        self.classification = classification
        self.use_bf16 = use_bf16

    def forward(self, images: Tensor, labels: Tensor) -> tuple[Tensor, Tensor]:
        """Run preprocessing, the backbone, and the head on one batch.

        Args:
            images: A batch of images with pixel values in ``[0, 255]``.
                Expected shape is ``(batch_size, 3, 224, 224)``.
            labels: A batch of labels. The decode functions in this module
                produce shape ``(batch_size, 1)``.

        Returns:
            The ``(top1, top5)`` pair returned by ``classification``.
        """
        features = self.preprocessing(images)
        features = features.to(torch.bfloat16) if self.use_bf16 else features
        logits = self.model(features)
        return self.classification(logits, labels)
168
169
def _expand(vals: list[float], batch_size: int, res: int) -> Tensor:
    """Build a ``(batch_size, 3, res, res)`` tensor whose channel ``c`` is ``vals[c]``.

    ``vals`` must contain exactly three values (one per channel).
    """
    per_channel = torch.tensor(vals).view(1, 3, 1, 1)
    target_shape = (batch_size, 3, res, res)
    # ``expand`` only creates a broadcast view with zero strides; ``clone``
    # materializes it into its own memory.
    return per_channel.expand(*target_shape).clone()
172
173
def get_model(
    batch_size: int,
    device_index: int,
    compile: bool,
    use_bf16: bool,
    model_type: str = "mobilenetv3_large_100",
) -> ModelBundle:
    """Build the computation model: preprocessing transform, backbone, and classification head.

    Args:
        batch_size: The batch size of the input. The normalization constants
            are materialized at this batch size.
        device_index: The index of the target GPU device.
        compile: Whether to compile the backbone and the preprocessing with
            ``torch.compile`` (``mode="max-autotune"``).
        use_bf16: Whether to use bfloat16 for the model.
        model_type: The type of the model. Passed to ``timm.create_model()``.

    Returns:
        The resulting computation model.
    """
    import timm

    device = torch.device(f"cuda:{device_index}")

    backbone = timm.create_model(model_type, pretrained=True).eval().to(device=device)
    if use_bf16:
        backbone = backbone.to(dtype=torch.bfloat16)

    # Per-channel RGB mean / std (the commonly used ImageNet statistics).
    preprocessing = Preprocessing(
        mean=_expand([0.4850, 0.4560, 0.4060], batch_size, 224),
        std=_expand([0.2290, 0.2240, 0.2250], batch_size, 224),
    ).to(device)
    head = Classification().to(device)

    if compile:
        mode = "max-autotune"
        with torch.no_grad():
            backbone = torch.compile(backbone, mode=mode)
            preprocessing = torch.compile(preprocessing, mode=mode)

    return ModelBundle(backbone, preprocessing, head, use_bf16)  # pyre-ignore[6]
217
218
def get_decode_func(
    device_index: int,
    width: int = 224,
    height: int = 224,
) -> Callable[[list[tuple[str, int]]], tuple[Tensor, Tensor]]:
    """Get a function to decode images from a list of paths.

    Each image is scaled to 256x256, cropped to ``width`` x ``height`` and
    converted to RGB by an FFmpeg filter, then the batch is placed on the GPU.

    Args:
        device_index: The index of the target GPU device. Both the decoded
            images and the labels are placed on this device.
        width: The width of the decoded image.
        height: The height of the decoded image.

    Returns:
        A (synchronous) function that takes ``(path, label)`` pairs and returns
        the decoded images as a batch tensor of NCHW format and the labels as
        a tensor of shape ``(batch_size, 1)``.
    """
    device: torch.device = torch.device(f"cuda:{device_index}")

    filter_desc: str | None = spdl.io.get_video_filter_desc(
        scale_width=256,
        scale_height=256,
        crop_width=width,
        crop_height=height,
        pix_fmt="rgb24",
    )

    # Build the CUDA config once and reuse it for every batch, as the nvJPEG
    # path does. It must target ``device_index``: a hard-coded index 0 would
    # put images and labels on different GPUs whenever ``device_index != 0``.
    device_config: spdl.io.CUDAConfig = spdl.io.cuda_config(
        device_index=device_index,
        allocator=(
            torch.cuda.caching_allocator_alloc,
            torch.cuda.caching_allocator_delete,
        ),
    )

    def decode_images(items: list[tuple[str, int]]) -> tuple[Tensor, Tensor]:
        paths = [path for path, _ in items]
        # One label per row -> shape (batch_size, 1).
        labels = torch.tensor(
            [[label] for _, label in items], dtype=torch.int64
        ).to(device)
        buffer = spdl.io.load_image_batch(
            paths,
            width=None,
            height=None,
            pix_fmt=None,
            strict=True,
            filter_desc=filter_desc,
            device_config=device_config,
        )
        # The decoded batch is channel-last; reorder it to NCHW for the model.
        batch = spdl.io.to_torch(buffer).permute((0, 3, 1, 2))
        return batch, labels

    return decode_images
269
270
def _get_experimental_nvjpeg_decode_function(
    device_index: int,
    width: int = 224,
    height: int = 224,
) -> Callable[[list[tuple[str, int]]], tuple[Tensor, Tensor]]:
    """Get an experimental nvJPEG-based alternative to ``get_decode_func``.

    Args:
        device_index: The index of the target GPU device.
        width: The width of the decoded image.
        height: The height of the decoded image.

    Returns:
        A function that takes ``(path, label)`` pairs and returns the decoded
        image batch (as converted by ``spdl.io.to_torch``) together with
        labels of shape ``(batch_size, 1)`` on ``cuda:{device_index}``.
    """
    target = torch.device(f"cuda:{device_index}")
    # Built once and shared by every call of the returned function.
    cfg: spdl.io.CUDAConfig = spdl.io.cuda_config(
        device_index=device_index,
        allocator=(
            torch.cuda.caching_allocator_alloc,
            torch.cuda.caching_allocator_delete,
        ),
    )

    def decode_images_nvjpeg(
        items: list[tuple[str, int]],
    ) -> tuple[Tensor, Tensor]:
        image_paths = [path for path, _ in items]
        label_rows = [[label] for _, label in items]
        labels = torch.tensor(label_rows, dtype=torch.int64).to(target)
        # Unlike ``get_decode_func``, strict mode is not requested here.
        decoded = spdl.io.load_image_batch_nvjpeg(
            image_paths,
            device_config=cfg,
            width=width,
            height=height,
            pix_fmt="rgb",
        )
        return spdl.io.to_torch(decoded), labels

    return decode_images_nvjpeg
303
304
def get_dataloader(
    src: Iterator[tuple[str, int]],
    batch_size: int,
    decode_func: Callable[[list[tuple[str, int]]], tuple[Tensor, Tensor]],
    buffer_size: int,
    num_threads: int,
    *,
    timeout: float = 20,
) -> Iterator[tuple[Tensor, Tensor]]:
    """Build the dataloader for the ImageNet classification task.

    The dataloader groups ``src`` items into batches of ``batch_size`` and
    drops a final incomplete batch. It uses ``decode_func`` to decode images
    concurrently and send the resulting data to the GPU.

    Args:
        src: The source of ``(path, label)`` pairs, for example
            ``spdl.source.imagenet.ImageNet``.
        batch_size: The number of images in a batch.
        decode_func: The function that turns a list of ``(path, label)`` pairs
            into an ``(images, labels)`` batch.
        buffer_size: The size of the buffer for the dataloader sink.
        num_threads: The number of worker threads.
        timeout: Forwarded as the ``timeout`` argument of
            ``spdl.dataloader.DataLoader``. Defaults to ``20``, the value
            previously hard-coded here. The unit is presumably seconds; confirm
            against the ``DataLoader`` documentation.

    Returns:
        The dataloader; iterating it yields ``(images, labels)`` batches.
    """
    return DataLoader(  # pyre-ignore[7]
        src,
        batch_size=batch_size,
        drop_last=True,
        aggregator=decode_func,
        buffer_size=buffer_size,
        num_threads=num_threads,
        timeout=timeout,
    )
334
335
def benchmark(
    dataloader: Iterator[tuple[Tensor, Tensor]],
    model: ModelBundle,
    max_batches: float = float("nan"),
) -> None:
    """The main loop that measures the performance of dataloading and model inference.

    The first 20 batches are treated as warm-up: when the 21st batch arrives,
    the timer and all counters are reset. Throughput and top-1/top-5 accuracy
    are logged when the loop ends, including when it is interrupted by an
    exception.

    Args:
        dataloader: The dataloader to benchmark. Yields ``(images, labels)``
            batches.
        model: The model to benchmark. Called as ``model(images, labels)`` and
            expected to return ``(top1, top5)`` hit counts.
        max_batches: The number of batches (warm-up included) to process before
            stopping. The default ``nan`` never satisfies the stop condition,
            so the whole dataloader is consumed.
    """
    _LG.info("Running inference.")
    num_frames, num_correct_top1, num_correct_top5 = 0, 0, 0
    t0 = time.monotonic()
    try:
        for i, (batch, labels) in enumerate(dataloader):
            if i == 20:
                # Warm-up is over: restart the measurement from this batch.
                t0 = time.monotonic()
                num_frames, num_correct_top1, num_correct_top5 = 0, 0, 0

            with (
                torch.profiler.record_function(f"iter_{i}"),
                spdl.io.utils.trace_event(f"iter_{i}"),
            ):
                top1, top5 = model(batch, labels)

            num_frames += batch.shape[0]
            # Keep the counts as tensors; converting them to Python numbers
            # every batch would synchronize with the device every batch.
            num_correct_top1 += top1
            num_correct_top5 += top5

            if i + 1 >= max_batches:
                break
    finally:
        # Materialize the counts *before* reading the clock. For device
        # tensors this waits for all queued inference work to finish, so the
        # elapsed time covers the work itself and not only its submission.
        # ``int()`` also accepts the plain ``0`` left when nothing was counted.
        num_correct_top1 = int(num_correct_top1)
        num_correct_top5 = int(num_correct_top5)
        elapsed = time.monotonic() - t0
        if num_frames != 0:
            fps = num_frames / elapsed
            _LG.info(f"FPS={fps:.2f} ({num_frames}/{elapsed:.2f})")
        acc1 = 0 if num_frames == 0 else num_correct_top1 / num_frames
        _LG.info(f"Accuracy (top1)={acc1:.2%} ({num_correct_top1}/{num_frames})")
        acc5 = 0 if num_frames == 0 else num_correct_top5 / num_frames
        _LG.info(f"Accuracy (top5)={acc5:.2%} ({num_correct_top5}/{num_frames})")
382
383
def _get_dataloader(
    args: Namespace, device_index: int
) -> Iterator[tuple[Tensor, Tensor]]:
    """Build the ImageNet dataloader described by the parsed CLI options.

    Uses the experimental nvJPEG decoder when ``--use-nvjpeg`` was given and
    the default decoder otherwise.
    """
    make_decoder = (
        _get_experimental_nvjpeg_decode_function
        if args.use_nvjpeg
        else get_decode_func
    )
    return get_dataloader(
        ImageNet(args.root_dir, split=args.split),  # pyre-ignore[6]
        args.batch_size,
        make_decoder(device_index),
        args.buffer_size,
        args.num_threads,
    )
401
402
def entrypoint(args_: list[str] | None = None) -> None:
    """CLI entrypoint. Run pipeline, transform and model and measure its performance.

    With ``--trace``, a PyTorch profiler trace (``<trace>.json``) and an SPDL
    trace (``<trace>.pftrace``) are written. ``.nvjpeg`` is appended to the
    base name when ``--use-nvjpeg`` is set.
    """
    args = _parse_args(args_)
    _init_logging(args.debug)
    _LG.info(args)

    device_index = 0
    model = get_model(args.batch_size, device_index, args.compile, args.use_bf16)
    dataloader = _get_dataloader(args, device_index)

    trace_path = f"{args.trace}.nvjpeg" if args.use_nvjpeg else f"{args.trace}"

    # Contexts are entered in order and exited in reverse, like a multi-item
    # ``with`` statement.
    with contextlib.ExitStack() as stack:
        stack.enter_context(torch.no_grad())
        prof = stack.enter_context(
            profile() if args.trace else contextlib.nullcontext()
        )
        stack.enter_context(
            spdl.io.utils.tracing(
                f"{trace_path}.pftrace", enable=args.trace is not None
            )
        )
        benchmark(dataloader, model, args.max_batches)

    if args.trace:
        # pyrefly: ignore [missing-attribute]
        prof.export_chrome_trace(f"{trace_path}.json")
428
429
def _init_logging(debug: bool = False) -> None:
    """Configure root logging at DEBUG level if ``debug`` is set, INFO otherwise."""
    logging.basicConfig(
        format="%(asctime)s [%(filename)s:%(lineno)d] [%(levelname)s] %(message)s",
        level=logging.DEBUG if debug else logging.INFO,
    )
434
435
if __name__ == "__main__":
    # No explicit arguments: the options are read from ``sys.argv``.
    entrypoint(args_=None)
API Reference
Functions
- entrypoint(args_: list[str] | None = None) None[source]¶
CLI entrypoint. Run pipeline, transform and model and measure its performance.
- benchmark(dataloader: Iterator[tuple[Tensor, Tensor]], model: ModelBundle, max_batches: float = nan) None[source]¶
The main loop that measures the performance of dataloading and model inference.
- Parameters:
dataloader – The dataloader to benchmark.
model – The model to benchmark.
max_batches – The number of batches to process before stopping.
- get_decode_func(device_index: int, width: int = 224, height: int = 224) Callable[[list[tuple[str, int]]], tuple[Tensor, Tensor]][source]¶
Get a function to decode images from a list of paths.
- Parameters:
device_index – The index of the target GPU device.
width – The width of the decoded image.
height – The height of the decoded image.
- Returns:
A function that decodes images into a batch tensor of NCHW format and labels of shape
(batch_size, 1).
- get_dataloader(src: Iterator[tuple[str, int]], batch_size: int, decode_func: Callable[[list[tuple[str, int]]], tuple[Tensor, Tensor]], buffer_size: int, num_threads: int) Iterator[tuple[Tensor, Tensor]][source]¶
Build the dataloader for the ImageNet classification task.
The dataloader uses decode_func to decode images concurrently and sends the resulting data to the GPU.
- Parameters:
src – The source of (path, label) pairs, for example spdl.source.imagenet.ImageNet.
batch_size – The number of images in a batch.
decode_func – The function to decode images.
buffer_size – The size of the buffer for the dataloader sink
num_threads – The number of worker threads.
- get_model(batch_size: int, device_index: int, compile: bool, use_bf16: bool, model_type: str = 'mobilenetv3_large_100') ModelBundle[source]¶
Build the computation model, including the preprocessing transform, the model backbone, and the classification head.
- Parameters:
batch_size – The batch size of the input.
device_index – The index of the target GPU device.
compile – Whether to compile the model.
use_bf16 – Whether to use bfloat16 for the model.
model_type – The type of the model. Passed to
timm.create_model().
- Returns:
The resulting computation model.
Classes
- class ModelBundle[source]¶
Bundle the transform, model backbone, and classification head into a single module for a simple handling.
- forward(images: Tensor, labels: Tensor) tuple[Tensor, Tensor][source]¶
Given a batch of images and labels, compute the top1, top5 accuracy.
- Parameters:
images – A batch of images. The shape is (batch_size, 3, 224, 224).
labels – A batch of labels. The shape is (batch_size, 1), as produced by the decode functions.
- Returns:
A tuple of top1 and top5 accuracy.
- class Classification[source]¶
- forward(x: Tensor, labels: Tensor) tuple[Tensor, Tensor][source]¶
Given a batch of features and labels, compute the top1 and top5 accuracy.
- Parameters:
x – A batch of model outputs (class scores). The shape is (batch_size, num_classes).
labels – A batch of labels. The shape is (batch_size, 1).
- Returns:
A tuple of top1 and top5 accuracy.