ImageNet classification

Benchmark the performance of loading images from local file systems and classifying them using a GPU.

This script builds the data loader and instantiates an image classification model on a GPU. The data loader decodes images and transfers the resulting batches to the GPU concurrently in background threads, while the foreground (main) thread runs the model on the batches one at a time.

flowchart LR
    subgraph MP [Main Process]
        subgraph BG [Background Thread]
            A[Source]
            subgraph TP1[Thread Pool]
                direction LR
                T1[Thread]
                T2[Thread]
                T3[Thread]
            end
        end
        subgraph FG [Main Thread]
            ML[Main loop]
        end
    end
    subgraph G[GPU]
        direction TB
        GM[Memory]
        T[Transform]
        M[Model]
    end
    A --> T1 -- Batch --> GM
    A --> T2 -- Batch --> GM
    A --> T3 -- Batch --> GM
    ML -.-> GM
    GM -.-> T -.-> M
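
The division of labor described above (a background thread feeding a thread pool, with the main thread consuming finished batches) can be sketched with the standard library alone. This is only an illustration of the pattern, not SPDL's implementation: `prefetch`, `fake_load`, and the queue-based sink are hypothetical names, and `fake_load` stands in for decoding images and copying them to the GPU.

```python
import queue
import threading
from concurrent.futures import ThreadPoolExecutor

_SENTINEL = object()  # marks the end of the stream


def prefetch(source, load_batch, num_threads=3, buffer_size=4):
    """Yield load_batch(item) for each item, loading in background threads."""
    sink: queue.Queue = queue.Queue(maxsize=buffer_size)

    def background():
        try:
            with ThreadPoolExecutor(max_workers=num_threads) as pool:
                # pool.map runs tasks concurrently but yields results in
                # input order (a property of this sketch, not of SPDL).
                for batch in pool.map(load_batch, source):
                    sink.put(batch)  # blocks while the buffer is full
        finally:
            sink.put(_SENTINEL)  # always unblock the consumer

    threading.Thread(target=background, daemon=True).start()
    while (batch := sink.get()) is not _SENTINEL:
        yield batch


def fake_load(i):
    # Hypothetical stand-in for "decode a batch and move it to the GPU".
    return [i] * 2


# The main thread consumes batches one at a time, like the main loop.
results = [sum(batch) for batch in prefetch(range(5), fake_load)]
print(results)
```

The bounded queue plays the role of the `--buffer-size` option: it limits how far the background workers can run ahead of the consumer.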

To run the benchmark, invoke the script with the ImageNet root directory and the split to evaluate, as follows.

python imagenet_classification.py \
    --root-dir ~/imagenet/ \
    --split val
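
The command above uses only two of the script's options. The opt-out flags `--no-compile` and `--no-bf16` are defined with `action="store_false"`, so compilation and bfloat16 are enabled unless the flag is given. The snippet below copies three of the argument definitions from the script's `_parse_args` into a standalone parser to show the resulting defaults; the variable names `defaults` and `opted_out` are illustrative.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=32)
# store_false flags default to True and flip to False when passed.
parser.add_argument("--no-compile", action="store_false", dest="compile")
parser.add_argument("--no-bf16", action="store_false", dest="use_bf16")

defaults = parser.parse_args([])
opted_out = parser.parse_args(["--no-compile", "--batch-size", "64"])

print(defaults.compile, defaults.use_bf16, defaults.batch_size)
print(opted_out.compile, opted_out.use_bf16, opted_out.batch_size)
```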

Source

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Benchmark the performance of loading images from local file systems and
classifying them using a GPU.

This script builds the data loader and instantiates an image
classification model on a GPU.
The data loader transfers the batch image data to the GPU concurrently, and
the foreground thread runs the model on the batches one at a time.

.. include:: ../plots/imagenet_classification_chart.txt

To run the benchmark, invoke the script as follows.

.. code-block::

   python imagenet_classification.py --root-dir ~/imagenet/ --split val
"""

# pyre-strict

import argparse
import contextlib
import logging
import time
from argparse import Namespace
from collections.abc import Callable, Iterator
from pathlib import Path

import spdl.io
import spdl.io.utils
import torch
from spdl.dataloader import DataLoader
from spdl.source.imagenet import ImageNet
from torch import Tensor
from torch.profiler import profile

_LG: logging.Logger = logging.getLogger(__name__)


__all__ = [
    "entrypoint",
    "benchmark",
    "get_decode_func",
    "get_dataloader",
    "get_model",
    "ModelBundle",
    "Classification",
    "Preprocessing",
]


def _parse_args(args: list[str] | None) -> Namespace:
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--root-dir", type=Path, required=True)
    parser.add_argument("--max-batches", type=int, default=float("inf"))
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--split", default="val", choices=["train", "val"])
    parser.add_argument("--trace", type=Path)
    parser.add_argument("--buffer-size", type=int, default=16)
    parser.add_argument("--num-threads", type=int, default=16)
    parser.add_argument("--no-compile", action="store_false", dest="compile")
    parser.add_argument("--no-bf16", action="store_false", dest="use_bf16")
    parser.add_argument("--use-nvjpeg", action="store_true")
    ns = parser.parse_args(args)
    # Keep traces small: when tracing, stop after a fixed number of batches.
    if ns.trace:
        ns.max_batches = 60
    return ns


# Handroll the transforms so as to support `torch.compile`
class Preprocessing(torch.nn.Module):
    """Perform pixel normalization and data type conversion.

    Args:
        mean: The mean value of the dataset.
        std: The standard deviation of the dataset.
    """

    def __init__(self, mean: Tensor, std: Tensor) -> None:
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize the given image batch.

        Args:
            x: The input image batch. Pixel values are expected to be
                in the range of ``[0, 255]``.
        Returns:
            The normalized image batch.
        """
        x = x.float() / 255.0
        # pyrefly: ignore [unsupported-operation]
        return (x - self.mean) / self.std


class Classification(torch.nn.Module):
    """Classification()"""

    def forward(self, x: Tensor, labels: Tensor) -> tuple[Tensor, Tensor]:
        """Given a batch of features and labels, compute the top1 and top5 accuracy.

        Args:
            x: A batch of model outputs (logits).
                The shape is ``(batch_size, num_classes)``.
            labels: A batch of labels. The shape is ``(batch_size, 1)``.

        Returns:
            A tuple of the numbers of correct top1 and top5 predictions.
        """

        probs = torch.nn.functional.softmax(x, dim=-1)
        top_prob, top_catid = torch.topk(probs, 5)
        # ``labels`` has shape (batch_size, 1), so it broadcasts against
        # the (batch_size, k) tensor of predicted category IDs.
        top1 = (top_catid[:, :1] == labels).sum()
        top5 = (top_catid == labels).sum()
        return top1, top5


class ModelBundle(torch.nn.Module):
    """ModelBundle()

    Bundle the transform, model backbone, and classification head into a single module
    for simpler handling."""

    def __init__(
        self,
        model: torch.nn.Module,
        preprocessing: Preprocessing,
        classification: Classification,
        use_bf16: bool,
    ) -> None:
        super().__init__()
        self.model = model
        self.preprocessing = preprocessing
        self.classification = classification
        self.use_bf16 = use_bf16

    def forward(self, images: Tensor, labels: Tensor) -> tuple[Tensor, Tensor]:
        """Given a batch of images and labels, compute the top1, top5 accuracy.

        Args:
            images: A batch of images. The shape is ``(batch_size, 3, 224, 224)``.
            labels: A batch of labels. The shape is ``(batch_size, 1)``.

        Returns:
            A tuple of the numbers of correct top1 and top5 predictions.
        """

        x = self.preprocessing(images)

        if self.use_bf16:
            x = x.to(torch.bfloat16)

        output = self.model(x)

        return self.classification(output, labels)


def _expand(vals: list[float], batch_size: int, res: int) -> Tensor:
    return torch.tensor(vals).view(1, 3, 1, 1).expand(batch_size, 3, res, res).clone()


def get_model(
    batch_size: int,
    device_index: int,
    compile: bool,
    use_bf16: bool,
    model_type: str = "mobilenetv3_large_100",
) -> ModelBundle:
    """Build computation model, including transform, model, and classification head.

    Args:
        batch_size: The batch size of the input.
        device_index: The index of the target GPU device.
        compile: Whether to compile the model.
        use_bf16: Whether to use bfloat16 for the model.
        model_type: The type of the model. Passed to ``timm.create_model()``.

    Returns:
        The resulting computation model.
    """
    import timm

    device = torch.device(f"cuda:{device_index}")

    model = timm.create_model(model_type, pretrained=True)
    model = model.eval().to(device=device)

    if use_bf16:
        model = model.to(dtype=torch.bfloat16)

    preprocessing = Preprocessing(
        mean=_expand([0.4850, 0.4560, 0.4060], batch_size, 224),
        std=_expand([0.2290, 0.2240, 0.2250], batch_size, 224),
    ).to(device)

    classification = Classification().to(device)

    if compile:
        with torch.no_grad():
            mode = "max-autotune"
            model = torch.compile(model, mode=mode)
            preprocessing = torch.compile(preprocessing, mode=mode)

    return ModelBundle(model, preprocessing, classification, use_bf16)  # pyre-ignore[6]


def get_decode_func(
    device_index: int,
    width: int = 224,
    height: int = 224,
) -> Callable[[list[tuple[str, int]]], tuple[Tensor, Tensor]]:
    """Get a function to decode images from a list of paths.

    Args:
        device_index: The index of the target GPU device.
        width: The width of the decoded image.
        height: The height of the decoded image.

    Returns:
        Function to decode images into a batch tensor of NCHW format
        and labels of shape ``(batch_size, 1)``.
    """
    device: torch.device = torch.device(f"cuda:{device_index}")

    filter_desc: str | None = spdl.io.get_video_filter_desc(
        scale_width=256,
        scale_height=256,
        crop_width=width,
        crop_height=height,
        pix_fmt="rgb24",
    )

    def decode_images(items: list[tuple[str, int]]) -> tuple[Tensor, Tensor]:
        paths = [item for item, _ in items]
        labels = [[item] for _, item in items]
        labels = torch.tensor(labels, dtype=torch.int64).to(device)
        buffer = spdl.io.load_image_batch(
            paths,
            width=None,
            height=None,
            pix_fmt=None,
            strict=True,
            filter_desc=filter_desc,
            device_config=spdl.io.cuda_config(
                # Use the requested device (this was hard-coded to 0).
                device_index=device_index,
                allocator=(
                    torch.cuda.caching_allocator_alloc,
                    torch.cuda.caching_allocator_delete,
                ),
            ),
        )
        batch = spdl.io.to_torch(buffer)
        # NHWC -> NCHW
        batch = batch.permute((0, 3, 1, 2))
        return batch, labels

    return decode_images


def _get_experimental_nvjpeg_decode_function(
    device_index: int,
    width: int = 224,
    height: int = 224,
) -> Callable[[list[tuple[str, int]]], tuple[Tensor, Tensor]]:
    device: torch.device = torch.device(f"cuda:{device_index}")
    device_config: spdl.io.CUDAConfig = spdl.io.cuda_config(
        device_index=device_index,
        allocator=(
            torch.cuda.caching_allocator_alloc,
            torch.cuda.caching_allocator_delete,
        ),
    )

    def decode_images_nvjpeg(
        items: list[tuple[str, int]],
    ) -> tuple[Tensor, Tensor]:
        paths = [item for item, _ in items]
        labels = [[item] for _, item in items]
        labels = torch.tensor(labels, dtype=torch.int64).to(device)
        buffer = spdl.io.load_image_batch_nvjpeg(
            paths,
            device_config=device_config,
            width=width,
            height=height,
            pix_fmt="rgb",
            # strict=True,
        )
        batch = spdl.io.to_torch(buffer)
        return batch, labels

    return decode_images_nvjpeg


def get_dataloader(
    src: Iterator[tuple[str, int]],
    batch_size: int,
    decode_func: Callable[[list[tuple[str, int]]], tuple[Tensor, Tensor]],
    buffer_size: int,
    num_threads: int,
) -> Iterator[tuple[Tensor, Tensor]]:
    """Build the dataloader for the ImageNet classification task.

    The dataloader uses the ``decode_func`` for decoding images concurrently and
    sends the resulting data to GPU.

    Args:
        src: The source of the data. An iterator of ``(path, label)`` pairs,
            such as ``spdl.source.imagenet.ImageNet``.
        batch_size: The number of images in a batch.
        decode_func: The function to decode images.
        buffer_size: The size of the buffer for the dataloader sink.
        num_threads: The number of worker threads.

    Returns:
        An iterator over decoded ``(images, labels)`` batches.
    """
    return DataLoader(  # pyre-ignore[7]
        src,
        batch_size=batch_size,
        drop_last=True,
        aggregator=decode_func,
        buffer_size=buffer_size,
        num_threads=num_threads,
        timeout=20,
    )


def benchmark(
    dataloader: Iterator[tuple[Tensor, Tensor]],
    model: ModelBundle,
    max_batches: float = float("nan"),
) -> None:
    """The main loop that measures the performance of dataloading and model inference.

    Args:
        dataloader: The dataloader to benchmark.
        model: The model to benchmark.
        max_batches: The number of batches to process before stopping.
            With the default of NaN, the loop runs until the dataloader
            is exhausted.
    """

    _LG.info("Running inference.")
    num_frames, num_correct_top1, num_correct_top5 = 0, 0, 0
    t0 = time.monotonic()
    try:
        for i, (batch, labels) in enumerate(dataloader):
            # Treat the first 20 iterations as warm-up: restart the clock
            # and discard the counts accumulated so far.
            if i == 20:
                t0 = time.monotonic()
                num_frames, num_correct_top1, num_correct_top5 = 0, 0, 0

            with (
                torch.profiler.record_function(f"iter_{i}"),
                spdl.io.utils.trace_event(f"iter_{i}"),
            ):
                top1, top5 = model(batch, labels)

                num_frames += batch.shape[0]
                num_correct_top1 += top1
                num_correct_top5 += top5

            if i + 1 >= max_batches:
                break
    finally:
        elapsed = time.monotonic() - t0
        if num_frames != 0:
            num_correct_top1 = num_correct_top1.item()  # pyre-ignore[16]
            # pyrefly: ignore [missing-attribute]
            num_correct_top5 = num_correct_top5.item()
            fps = num_frames / elapsed
            _LG.info(f"FPS={fps:.2f} ({num_frames}/{elapsed:.2f})")
            acc1 = num_correct_top1 / num_frames
            _LG.info(f"Accuracy (top1)={acc1:.2%} ({num_correct_top1}/{num_frames})")
            acc5 = num_correct_top5 / num_frames
            _LG.info(f"Accuracy (top5)={acc5:.2%} ({num_correct_top5}/{num_frames})")


def _get_dataloader(
    args: Namespace, device_index: int
) -> Iterator[tuple[Tensor, Tensor]]:
    src = ImageNet(args.root_dir, split=args.split)

    if args.use_nvjpeg:
        decode_func = _get_experimental_nvjpeg_decode_function(device_index)
    else:
        decode_func = get_decode_func(device_index)

    return get_dataloader(
        src,  # pyre-ignore[6]
        args.batch_size,
        decode_func,
        args.buffer_size,
        args.num_threads,
    )


def entrypoint(args_: list[str] | None = None) -> None:
    """CLI entrypoint. Run the pipeline, transform, and model, and measure their performance."""

    args = _parse_args(args_)
    _init_logging(args.debug)
    _LG.info(args)

    device_index = 0
    model = get_model(args.batch_size, device_index, args.compile, args.use_bf16)
    dataloader = _get_dataloader(args, device_index)

    trace_path = f"{args.trace}"
    if args.use_nvjpeg:
        trace_path = f"{trace_path}.nvjpeg"

    with (
        torch.no_grad(),
        profile() if args.trace else contextlib.nullcontext() as prof,
        spdl.io.utils.tracing(f"{trace_path}.pftrace", enable=args.trace is not None),
    ):
        benchmark(dataloader, model, args.max_batches)

    if args.trace:
        # pyrefly: ignore [missing-attribute]
        prof.export_chrome_trace(f"{trace_path}.json")


def _init_logging(debug: bool = False) -> None:
    fmt = "%(asctime)s [%(filename)s:%(lineno)d] [%(levelname)s] %(message)s"
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(format=fmt, level=level)


if __name__ == "__main__":
    entrypoint()

API Reference

Functions

entrypoint(args_: list[str] | None = None) -> None

CLI entrypoint. Run the pipeline, transform, and model, and measure their performance.

benchmark(dataloader: Iterator[tuple[Tensor, Tensor]], model: ModelBundle, max_batches: float = nan) -> None

The main loop that measures the performance of dataloading and model inference.

Parameters:
  • dataloader – The dataloader to benchmark.

  • model – The model to benchmark.

  • max_batches – The number of batches to process before stopping. With the default of NaN, the loop runs until the dataloader is exhausted.
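
The timing logic of the main loop can be isolated from the model and the GPU. The sketch below mirrors the loop in `benchmark`: the first 20 iterations are treated as warm-up, after which the clock and the frame counter are reset, and the NaN default for `max_batches` never triggers the early exit because any comparison with NaN is false. `measure_fps` and its `run` callback are hypothetical names used only for this illustration.

```python
import time


def measure_fps(batches, run, warmup=20, max_batches=float("nan")):
    """Return (num_frames, elapsed_seconds) measured after the warm-up."""
    num_frames = 0
    t0 = time.monotonic()
    for i, batch in enumerate(batches):
        if i == warmup:
            # Discard warm-up measurements (e.g. compilation, cache fill).
            t0 = time.monotonic()
            num_frames = 0
        run(batch)
        num_frames += len(batch)
        # With max_batches=NaN this is always False, so the loop runs
        # until the input is exhausted.
        if i + 1 >= max_batches:
            break
    return num_frames, time.monotonic() - t0


batches = [[0] * 4 for _ in range(30)]  # 30 batches of 4 "frames"
all_frames, _ = measure_fps(batches, run=lambda b: None)
capped_frames, _ = measure_fps(batches, run=lambda b: None, max_batches=25)
short_frames, _ = measure_fps(batches[:5], run=lambda b: None)
print(all_frames, capped_frames, short_frames)
```

Note that when fewer than `warmup + 1` batches are processed, the reset never happens and the warm-up iterations are included in the measurement, as in the last call.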

get_decode_func(device_index: int, width: int = 224, height: int = 224) -> Callable[[list[tuple[str, int]]], tuple[Tensor, Tensor]]

Get a function to decode images from a list of paths.

Parameters:
  • device_index – The index of the target GPU device.

  • width – The width of the decoded image.

  • height – The height of the decoded image.

Returns:

Function to decode images into a batch tensor of NCHW format and labels of shape (batch_size, 1).

get_dataloader(src: Iterator[tuple[str, int]], batch_size: int, decode_func: Callable[[list[tuple[str, int]]], tuple[Tensor, Tensor]], buffer_size: int, num_threads: int) -> Iterator[tuple[Tensor, Tensor]]

Build the dataloader for the ImageNet classification task.

The dataloader uses the decode_func for decoding images concurrently and sends the resulting data to GPU.

Parameters:
  • src – The source of the data. An iterator of (path, label) pairs, such as spdl.source.imagenet.ImageNet.

  • batch_size – The number of images in a batch.

  • decode_func – The function to decode images.

  • buffer_size – The size of the buffer for the dataloader sink.

  • num_threads – The number of worker threads.
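
In the source, the loader is built with `drop_last=True` and the decode function is passed as the `aggregator`, so each group of `batch_size` `(path, label)` pairs is handed to the decode function as one list. The sketch below assumes the conventional meaning of these options (a trailing incomplete batch is discarded) and leaves out concurrency entirely; `batch_and_aggregate` and `split_items` are hypothetical names, and `split_items` only reproduces the path/label split from `decode_images` without decoding anything.

```python
from collections.abc import Callable, Iterable, Iterator
from itertools import islice


def batch_and_aggregate(
    src: Iterable, batch_size: int, aggregator: Callable[[list], object]
) -> Iterator:
    """Group items into full batches and apply the aggregator to each."""
    it = iter(src)
    # Stop as soon as a chunk is short: this is the drop_last behavior.
    while len(chunk := list(islice(it, batch_size))) == batch_size:
        yield aggregator(chunk)


def split_items(items: list[tuple[str, int]]) -> tuple[list[str], list[list[int]]]:
    paths = [path for path, _ in items]
    labels = [[label] for _, label in items]  # shape (batch_size, 1)
    return paths, labels


items = [(f"img{i}.jpg", i % 2) for i in range(7)]
batches = list(batch_and_aggregate(items, 3, split_items))
print(len(batches))  # 7 items with batch_size=3 give 2 full batches
print(batches[0])
```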

get_model(batch_size: int, device_index: int, compile: bool, use_bf16: bool, model_type: str = 'mobilenetv3_large_100') -> ModelBundle

Build computation model, including transform, model, and classification head.

Parameters:
  • batch_size – The batch size of the input.

  • device_index – The index of the target GPU device.

  • compile – Whether to compile the model.

  • use_bf16 – Whether to use bfloat16 for the model.

  • model_type – The type of the model. Passed to timm.create_model().

Returns:

The resulting computation model.

Classes

class ModelBundle

Bundle the transform, model backbone, and classification head into a single module for simpler handling.

forward(images: Tensor, labels: Tensor) -> tuple[Tensor, Tensor]

Given a batch of images and labels, compute the top1 and top5 accuracy.

Parameters:
  • images – A batch of images. The shape is (batch_size, 3, 224, 224).

  • labels – A batch of labels. The shape is (batch_size, 1), as produced by the decode function.

Returns:

A tuple of top1 and top5 accuracy.

class Classification

forward(x: Tensor, labels: Tensor) -> tuple[Tensor, Tensor]

Given a batch of features and labels, compute the top1 and top5 accuracy.

Parameters:
  • x – A batch of model outputs (logits). The shape is (batch_size, num_classes).

  • labels – A batch of labels. The shape is (batch_size, 1), as produced by the decode function.

Returns:

A tuple of top1 and top5 accuracy.
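
What `forward` returns are counts of correct predictions, which the main loop later divides by the number of frames. The following plain-Python sketch shows the same counting without torch, assuming labels of shape (batch_size, 1) as produced by the decode function. `topk_correct` is a hypothetical name, and ranking raw logits is used instead of softmax probabilities because softmax preserves the ordering.

```python
def topk_correct(logits, labels, k=5):
    """Count samples whose label is the top-1 / within the top-k predictions."""
    top1 = top5 = 0
    for row, (label,) in zip(logits, labels):
        # Indices of the k largest scores, best first.
        ranked = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
        top1 += ranked[0] == label
        top5 += label in ranked
    return top1, top5


logits = [
    [0.1, 2.0, 0.3, 0.0, 1.5, 0.2],  # best class: 1
    [3.0, 0.5, 0.4, 0.1, 0.2, 0.3],  # class 3 has the lowest score
    [0.9, 0.1, 0.8, 0.7, 0.6, 0.5],  # class 2 is second best
]
labels = [[1], [3], [2]]
counts = topk_correct(logits, labels)
print(counts)
```

With six classes and k=5, only the lowest-scoring class falls outside the top five, so the second sample is the only one that misses both counts.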

class Preprocessing(mean: Tensor, std: Tensor)

Perform pixel normalization and data type conversion.

Parameters:
  • mean – The mean value of the dataset.

  • std – The standard deviation of the dataset.

forward(x: Tensor) -> Tensor

Normalize the given image batch.

Parameters:
  • x – The input image batch. Pixel values are expected to be in the range of [0, 255].

Returns:

The normalized image batch.
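
For a single RGB pixel, the normalization amounts to scaling to [0, 1] and then standardizing each channel. The sketch below uses the per-channel mean and standard deviation that `get_model` passes to `Preprocessing`; `normalize_pixel` is a hypothetical helper for illustration, whereas the module itself applies the same arithmetic to whole batches on the GPU.

```python
# Per-channel statistics as passed to Preprocessing in get_model.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Map 8-bit RGB values to standardized floats, channel by channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print(white)  # all channels positive, red is about 2.25
print(black)  # all channels negative
```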