Video dataloading

This example uses SPDL to decode and batch video frames, then send them to GPU.

The structure of the pipeline is identical to that of image_dataloading.

Basic Usage

Running this example requires a dataset consisting of videos.

For example, to run this example with the Kinetics dataset:

  1. Download the Kinetics dataset. https://github.com/cvdfoundation/kinetics-dataset provides scripts to facilitate this.

  2. Create a list containing the downloaded videos.

    cd /data/users/moto/kinetics-dataset/k400/
    find train -name '*.mp4' > ~/kinetics400.train.flist
    
  3. Run the script.

    python examples/video_dataloading.py \
      --input-flist ~/kinetics400.train.flist \
      --prefix /data/users/moto/kinetics-dataset/k400/ \
      --num-threads 8
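The file list from step 2 can also be produced with Python's standard library, for example on systems without `find`. The sketch below assumes the same layout as the shell commands (videos under `train/`, paths written relative to the dataset root so that `--prefix` can be prepended later); the function name `write_flist` is hypothetical and not part of the example script.

```python
from pathlib import Path


def write_flist(root: str, subdir: str, output: str, pattern: str = "*.mp4") -> int:
    """Write paths of files matching ``pattern`` under ``root/subdir``, one per line.

    Hypothetical helper. Paths are relative to ``root``, mirroring
    ``cd root; find subdir -name pattern``. Returns the number of paths written.
    """
    root_path = Path(root)
    # rglob walks subdirectories recursively, like `find`; sorting makes the order stable.
    paths = sorted(
        p.relative_to(root_path)
        for p in (root_path / subdir).rglob(pattern)
        if p.is_file()
    )
    with open(output, "w") as f:
        for p in paths:
            f.write(f"{p.as_posix()}\n")
    return len(paths)
```

For the Kinetics layout above this would be called as `write_flist("/data/users/moto/kinetics-dataset/k400/", "train", str(Path.home() / "kinetics400.train.flist"))`. Unlike `find`, the output is sorted, which keeps the per-worker splits reproducible across runs.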
    

Using GPU video decoder

When SPDL is built with NVDEC integration enabled and the GPUs support NVDEC, providing the --nvdec option switches the video decoder to NVDEC, using spdl.io.decode_packets_nvdec(). When using this option, adjust the number of threads (the number of concurrent decodings) to accommodate the number of hardware video decoders available on the GPUs. For details, please refer to https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new
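In this example, `--num-threads` is passed as the `concurrency` of the decode stage (see `_get_pipeline` in the source below), so it is also the upper bound on how many decodings run at the same time. The standard-library sketch below illustrates that relationship with a thread pool. It is not SPDL code, and `fake_decode` is a hypothetical stand-in for a real decoder call.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


def run_with_concurrency(items, decode, max_concurrency):
    """Apply ``decode`` to ``items`` with at most ``max_concurrency`` calls in flight.

    Returns the results (in input order) and the peak number of concurrent calls observed.
    """
    lock = threading.Lock()
    in_flight = 0
    peak = 0

    def tracked(item):
        nonlocal in_flight, peak
        with lock:
            in_flight += 1
            peak = max(peak, in_flight)
        try:
            return decode(item)
        finally:
            with lock:
                in_flight -= 1

    # The pool size is the concurrency limit: no more than max_concurrency workers exist.
    with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
        results = list(pool.map(tracked, items))
    return results, peak


def fake_decode(path):
    # Hypothetical stand-in for a decode call; sleeps to simulate work.
    time.sleep(0.01)
    return f"decoded:{path}"
```

Calling `run_with_concurrency([f"v{i}.mp4" for i in range(20)], fake_decode, 4)` never runs more than four decodes at once. With NVDEC, this bound is the knob that the paragraph above suggests matching to the hardware decoders available.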

Note

This example decodes videos from beginning to end, so using NVDEC speeds up decoding overall. However, when frames are sampled, CPU decoding with higher concurrency often yields higher throughput.

Source

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

r"""This example uses SPDL to decode and batch video frames, then send them to GPU.

The structure of the pipeline is identical to that of
:py:mod:`image_dataloading`.

Basic Usage
-----------

Running this example requires a dataset consisting of videos.

For example, to run this example with the Kinetics dataset:

1. Download the Kinetics dataset.
   https://github.com/cvdfoundation/kinetics-dataset provides scripts to facilitate this.
2. Create a list containing the downloaded videos.

   .. code-block::

      cd /data/users/moto/kinetics-dataset/k400/
      find train -name '*.mp4' > ~/kinetics400.train.flist

3. Run the script.

   .. code-block:: shell

      python examples/video_dataloading.py \
        --input-flist ~/kinetics400.train.flist \
        --prefix /data/users/moto/kinetics-dataset/k400/ \
        --num-threads 8

Using GPU video decoder
-----------------------

When SPDL is built with NVDEC integration enabled and the GPUs support NVDEC,
providing the ``--nvdec`` option switches the video decoder to NVDEC, using
:py:func:`spdl.io.decode_packets_nvdec`. When using this option, adjust the
number of threads (the number of concurrent decodings) to accommodate
the number of hardware video decoders available on the GPUs.
For details, please refer to https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new

.. note::

   This example decodes videos from beginning to end, so using NVDEC
   speeds up decoding overall. However, when frames are sampled,
   CPU decoding with higher concurrency often yields higher throughput.
"""

# pyre-strict

import argparse
import logging
import signal
import time
from argparse import Namespace
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from pathlib import Path
from threading import Event

import spdl.io
import spdl.io.utils
import torch
from spdl.pipeline import Pipeline, PipelineBuilder
from torch import Tensor

_LG: logging.Logger = logging.getLogger(__name__)

__all__ = [
    "entrypoint",
    "worker_entrypoint",
    "benchmark",
    "source",
    "decode_video",
    "decode_video_nvdec",
    "get_pipeline",
    "PerfResult",
]


def _parse_args(args: list[str]) -> Namespace:
    parser = argparse.ArgumentParser(
        description=__doc__,
    )
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--input-flist", type=Path, required=True)
    parser.add_argument("--max-samples", type=int, default=float("inf"))
    parser.add_argument("--prefix", default="")
    parser.add_argument("--trace", type=Path)
    parser.add_argument("--queue-size", type=int, default=16)
    parser.add_argument("--num-threads", type=int, required=True)
    parser.add_argument("--worker-id", type=int, required=True)
    parser.add_argument("--num-workers", type=int, required=True)
    parser.add_argument("--nvdec", action="store_true")
    ns = parser.parse_args(args)
    if ns.trace:
        ns.max_samples = 320
    return ns


def source(
    input_flist: str,
    prefix: str,
    max_samples: int,
    split_size: int = 1,
    split_id: int = 0,
) -> Iterable[str]:
    """Iterate a file containing a list of paths, while optionally skipping some.

    Args:
        input_flist: A file that contains a list of video paths.
        prefix: Prepended to the paths in the list.
        max_samples: The maximum number of items to yield.
        split_size: Split the paths into this number of subsets.
        split_id: The index of this split. Paths at ``line_number % split_size == split_id`` are returned.

    Yields:
        The paths of the specified split.
    """
    with open(input_flist, "r") as f:
        num_yielded = 0
        for i, line in enumerate(f):
            if i % split_size != split_id:
                continue
            if line := line.strip():
                yield prefix + line

                if (num_yielded := num_yielded + 1) >= max_samples:
                    return


def decode_video(
    src: str | bytes,
    width: int,
    height: int,
    device_index: int,
) -> Tensor:
    """Decode video and send decoded frames to GPU.

    Args:
        src: Data source. Passed to :py:func:`spdl.io.demux_video`.
        width, height: The target resolution.
        device_index: The index of the target GPU.

    Returns:
        A GPU tensor representing the decoded video frames.
        The dtype is uint8 and the shape is ``[N, C, H, W]``, where ``N`` is the number
        of frames in the video and ``C`` is the number of color (RGB) channels.
    """
    packets = spdl.io.demux_video(src)
    frames = spdl.io.decode_packets(
        packets,
        filter_desc=spdl.io.get_filter_desc(
            packets,
            scale_width=width,
            scale_height=height,
            pix_fmt="rgb24",
        ),
    )
    buffer = spdl.io.convert_frames(frames)
    buffer = spdl.io.transfer_buffer(
        buffer,
        device_config=spdl.io.cuda_config(
            device_index=device_index,
            allocator=(
                torch.cuda.caching_allocator_alloc,
                torch.cuda.caching_allocator_delete,
            ),
        ),
    )
    return spdl.io.to_torch(buffer).permute(0, 2, 3, 1)


def decode_video_nvdec(
    src: str,
    device_index: int,
    width: int,
    height: int,
) -> Tensor:
    """Decode video using NVDEC.

    Args:
        src: Data source. Passed to :py:func:`spdl.io.demux_video`.
        device_index: The index of the target GPU.
        width, height: The target resolution.

    Returns:
        A GPU tensor representing the decoded video frames.
        The dtype is uint8 and the shape is ``[N, C, H, W]``, where ``N`` is the number
        of frames in the video and ``C`` is the number of color (RGB) channels.
    """
    packets = spdl.io.demux_video(src)
    buffer = spdl.io.decode_packets_nvdec(
        packets,
        device_config=spdl.io.cuda_config(
            device_index=device_index,
            allocator=(
                torch.cuda.caching_allocator_alloc,
                torch.cuda.caching_allocator_delete,
            ),
        ),
        scale_width=width,
        scale_height=height,
        pix_fmt="rgb",
    )
    return spdl.io.to_torch(buffer)[..., :3].permute(0, 2, 3, 1)


def _get_decode_fn(
    device_index: int, use_nvdec: bool, width: int = 222, height: int = 222
) -> Callable[[str], Tensor]:
    if use_nvdec:

        def _decode_func(src: str) -> Tensor:
            return decode_video_nvdec(src, device_index, width, height)

    else:

        def _decode_func(src: str) -> Tensor:
            return decode_video(src, width, height, device_index)

    return _decode_func


def get_pipeline(
    src: Iterable[str],
    decode_fn: Callable[[str], Tensor],
    decode_concurrency: int,
    num_threads: int,
    buffer_size: int = 3,
) -> Pipeline:
    """Construct the video loading pipeline.

    Args:
        src: Pipeline source. Generator that yields video paths. See :py:func:`source`.
        decode_fn: Function that decodes the given video and sends the decoded frames to GPU.
        decode_concurrency: The maximum number of decodings scheduled concurrently.
        num_threads: The number of threads in the pipeline.
        buffer_size: The size of the buffer for the resulting video Tensors.
    """
    return (
        PipelineBuilder()
        .add_source(src)
        .pipe(decode_fn, concurrency=decode_concurrency)
        .add_sink(buffer_size)
        .build(num_threads=num_threads, report_stats_interval=15)
    )


def _get_pipeline(args: Namespace) -> Pipeline:
    src = source(
        input_flist=args.input_flist,
        prefix=args.prefix,
        max_samples=args.max_samples,
        split_id=args.worker_id,
        split_size=args.num_workers,
    )

    decode_fn = _get_decode_fn(args.worker_id, args.nvdec)
    pipeline = get_pipeline(
        src,
        decode_fn,
        decode_concurrency=args.num_threads,
        num_threads=args.num_threads + 3,
        buffer_size=args.queue_size,
    )
    print(pipeline)
    return pipeline


@dataclass
class PerfResult:
    """Used to report the worker performance to the main process."""

    elapsed: float
    """The time it took to process all the inputs."""

    num_batches: int
    """The number of batches processed."""

    num_frames: int
    """The number of frames processed."""


def benchmark(
    dataloader: Iterable[Tensor],
    stop_requested: Event,
) -> PerfResult:
    """The main loop that measures the performance of dataloading.

    Args:
        dataloader: The dataloader to benchmark.
        stop_requested: Used to interrupt the benchmark loop.

    Returns:
        The performance result.
    """
    t0 = time.monotonic()
    num_frames = num_batches = 0
    try:
        # Each item is the frame tensor of one video, with shape [N, ...].
        for batch in dataloader:
            num_frames += batch.shape[0]
            num_batches += 1

            if stop_requested.is_set():
                break

    finally:
        elapsed = time.monotonic() - t0
        fps = num_frames / elapsed
        _LG.info(f"FPS={fps:.2f} ({num_frames} / {elapsed:.2f}), (Done {num_frames})")

    return PerfResult(elapsed, num_batches, num_frames)


def worker_entrypoint(args_: list[str]) -> PerfResult:
    """Entrypoint for a worker process. Load videos onto a GPU and measure the performance.

    It builds a Pipeline object using the :py:func:`get_pipeline` function and runs it
    with the :py:func:`benchmark` function.
    """
    args = _parse_args(args_)
    _init(args.debug, args.worker_id)

    _LG.info(args)

    pipeline = _get_pipeline(args)

    device = torch.device(f"cuda:{args.worker_id}")

    ev: Event = Event()

    def handler_stop_signals(_signum, _frame) -> None:
        ev.set()

    signal.signal(signal.SIGTERM, handler_stop_signals)

    # Warm up
    torch.zeros([1, 1], device=device)

    trace_path = f"{args.trace}.{args.worker_id}"
    with (
        pipeline.auto_stop(),
        spdl.io.utils.tracing(trace_path, enable=args.trace is not None),
    ):
        return benchmark(pipeline.get_iterator(), ev)


def _init_logging(debug: bool = False, worker_id: int | None = None) -> None:
    fmt = "%(asctime)s [%(levelname)s] %(message)s"
    if worker_id is not None:
        fmt = f"[{worker_id}:%(thread)d] {fmt}"
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(format=fmt, level=level)


def _init(debug: bool, worker_id: int) -> None:
    _init_logging(debug, worker_id)


def _parse_process_args(args: list[str] | None) -> tuple[Namespace, list[str]]:
    parser = argparse.ArgumentParser(
        description=__doc__,
    )
    parser.add_argument("--num-workers", type=int, default=8)
    return parser.parse_known_args(args)


def entrypoint(args: list[str] | None = None) -> None:
    """CLI entrypoint. Launch the worker processes, each of which loads videos and sends them to the GPU."""
    ns, args = _parse_process_args(args)

    args_set = [
        [*args, f"--worker-id={i}", f"--num-workers={ns.num_workers}"]
        for i in range(ns.num_workers)
    ]

    from multiprocessing import Pool

    with Pool(processes=ns.num_workers) as pool:
        _init_logging()
        _LG.info("Spawned: %d workers", ns.num_workers)

        vals = pool.map(worker_entrypoint, args_set)

    ave_time = sum(v.elapsed for v in vals) / len(vals)
    total_frames = sum(v.num_frames for v in vals)
    total_batches = sum(v.num_batches for v in vals)

    _LG.info(f"{ave_time=:.2f}, {total_frames=}, {total_batches=}")

    FPS = total_frames / ave_time
    BPS = total_batches / ave_time
    _LG.info(f"Aggregated {FPS=:.2f}, {BPS=:.2f}")


if __name__ == "__main__":
    entrypoint()

Functions

entrypoint(args: list[str] | None = None) → None

CLI entrypoint. Launch the worker processes, each of which loads videos and sends them to the GPU.
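entrypoint consumes only --num-workers and forwards every other argument to each worker, appending that worker's --worker-id and --num-workers. The sketch below reproduces that fan-out with argparse's `parse_known_args`, mirroring `_parse_process_args` and the `args_set` list in the source listing; `fan_out_args` is a hypothetical name.

```python
import argparse


def fan_out_args(argv: list[str]) -> list[list[str]]:
    """Split ``argv`` into per-worker argument lists, as ``entrypoint`` does."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-workers", type=int, default=8)
    # parse_known_args returns (namespace, unrecognized_args) instead of failing on extras.
    ns, rest = parser.parse_known_args(argv)
    return [
        [*rest, f"--worker-id={i}", f"--num-workers={ns.num_workers}"]
        for i in range(ns.num_workers)
    ]
```

For example, `["--input-flist", "x.flist", "--num-threads", "8", "--num-workers", "2"]` becomes two argument lists that differ only in `--worker-id=0` and `--worker-id=1`, which is why the worker parser can mark both flags as required.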

worker_entrypoint(args_: list[str]) → PerfResult

Entrypoint for a worker process. Load videos onto a GPU and measure the performance.

It builds a Pipeline object using the get_pipeline() function and runs it with the benchmark() function.
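worker_entrypoint turns SIGTERM into a `threading.Event` that the benchmark loop polls, so a terminated worker leaves its loop through the normal path and still reports its timing. A minimal standalone sketch of that pattern follows; `make_stop_event` is a hypothetical helper, and note that `signal.signal` may only be called from the main thread of the process.

```python
import signal
from threading import Event


def make_stop_event(install: bool = True):
    """Return an Event that is set when SIGTERM arrives, plus the handler itself.

    ``install=False`` skips registration (``signal.signal`` only works in the main thread).
    """
    ev = Event()

    def handler_stop_signals(_signum, _frame) -> None:
        # Only flip the flag; the consumer loop decides when to break.
        ev.set()

    if install:
        signal.signal(signal.SIGTERM, handler_stop_signals)
    return ev, handler_stop_signals
```

Keeping the handler to a single `ev.set()` avoids doing real work inside a signal handler; the loop checks the flag once per item and exits on its own.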

benchmark(dataloader: Iterable[Tensor], stop_requested: Event) → PerfResult

The main loop that measures the performance of dataloading.

Parameters:
  • dataloader – The dataloader to benchmark.

  • stop_requested – Used to interrupt the benchmark loop.

Returns:

The performance result.
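The sketch below counts frames and batches the way PerfResult documents them, treating each item yielded by the pipeline as one batch whose leading dimension is the number of frames, and stopping early once the event is set. It runs without a GPU: `FakeVideo` and `count_frames` are hypothetical stand-ins, and only the `shape` attribute of each item is used.

```python
import time
from dataclasses import dataclass
from threading import Event


@dataclass
class FakeVideo:
    # Hypothetical stand-in for a decoded video tensor; only ``shape`` is used.
    shape: tuple[int, ...]


def count_frames(dataloader, stop_requested: Event):
    """Return (elapsed, num_batches, num_frames) for the items in ``dataloader``."""
    t0 = time.monotonic()
    num_frames = num_batches = 0
    for batch in dataloader:
        num_frames += batch.shape[0]  # leading dimension = number of frames
        num_batches += 1
        if stop_requested.is_set():
            break
    return time.monotonic() - t0, num_batches, num_frames
```

Two videos with 30 and 45 frames give 2 batches and 75 frames; with the event already set, the loop stops after the first item.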

source(input_flist: str, prefix: str, max_samples: int, split_size: int = 1, split_id: int = 0) → Iterable[str]

Iterate a file containing a list of paths, while optionally skipping some.

Parameters:
  • input_flist – A file that contains a list of video paths.

  • prefix – Prepended to the paths in the list.

  • max_samples – The maximum number of items to yield.

  • split_size – Split the paths into this number of subsets.

  • split_id – The index of this split. Paths at line_number % split_size == split_id are returned.

Yields:

The paths of the specified split.
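source() assigns lines round-robin by line number and applies max_samples per split, after the split. The standalone copy below uses the same selection logic as source() in the listing but reads from any iterable of lines instead of a file, so it runs without a dataset; the name `split_lines` is hypothetical.

```python
from collections.abc import Iterable, Iterator


def split_lines(
    lines: Iterable[str],
    prefix: str,
    max_samples: float,
    split_size: int = 1,
    split_id: int = 0,
) -> Iterator[str]:
    """Same selection logic as ``source``, but over an iterable of lines."""
    num_yielded = 0
    for i, line in enumerate(lines):
        # Round-robin split: line i belongs to split ``i % split_size``.
        if i % split_size != split_id:
            continue
        # Blank lines still occupy a line number but are not yielded.
        if line := line.strip():
            yield prefix + line
            if (num_yielded := num_yielded + 1) >= max_samples:
                return
```

With five lines `a, b, <blank>, c, d` and two splits, split 0 receives `a` and `d` (line 2 is blank) and split 1 receives `b` and `c`. Each worker reads the whole file but keeps only its own share of the lines.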

decode_video(src: str | bytes, width: int, height: int, device_index: int) → Tensor

Decode video and send decoded frames to GPU.

Parameters:
  • src – Data source. Passed to spdl.io.demux_video().

  • width – The target width.

  • height – The target height.

  • device_index – The index of the target GPU.

Returns:

A GPU tensor representing the decoded video frames. The dtype is uint8 and the shape is [N, C, H, W], where N is the number of frames in the video and C is the number of color (RGB) channels.

decode_video_nvdec(src: str, device_index: int, width: int, height: int) → Tensor

Decode video using NVDEC.

Parameters:
  • src – Data source. Passed to spdl.io.demux_video().

  • device_index – The index of the target GPU.

  • width – The target width.

  • height – The target height.

Returns:

A GPU tensor representing the decoded video frames. The dtype is uint8 and the shape is [N, C, H, W], where N is the number of frames in the video and C is the number of color (RGB) channels.

get_pipeline(src: Iterable[str], decode_fn: Callable[[str], Tensor], decode_concurrency: int, num_threads: int, buffer_size: int = 3) → Pipeline

Construct the video loading pipeline.

Parameters:
  • src – Pipeline source. Generator that yields video paths. See source().

  • decode_fn – Function that decodes the given video and sends the decoded frames to GPU.

  • decode_concurrency – The maximum number of decodings scheduled concurrently.

  • num_threads – The number of threads in the pipeline.

  • buffer_size – The size of the buffer for the resulting video Tensors.
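get_pipeline() chains three stages: a source iterator, a decode stage that runs up to decode_concurrency calls at once, and a sink that buffers at most buffer_size results. As a rough mental model only, the asyncio sketch below wires the same three stages with the standard library. It is not how spdl.pipeline is implemented, `run_pipeline` is a hypothetical name, and results come out in completion order, which may differ from input order.

```python
import asyncio
from collections.abc import Callable, Iterable


async def run_pipeline(
    src: Iterable[str],
    decode_fn: Callable[[str], object],
    decode_concurrency: int,
    buffer_size: int,
) -> list[object]:
    """Source -> concurrent decode -> bounded sink, consumed until exhausted."""
    sink: asyncio.Queue = asyncio.Queue(maxsize=buffer_size)  # bounded output buffer
    done = object()  # sentinel marking the end of the stream
    it = iter(src)  # shared source; pulled lazily by the workers

    async def worker() -> None:
        # Each worker takes the next source item, so at most
        # decode_concurrency decodes are in flight at any time.
        for item in it:
            # Run the blocking decode in a thread so the event loop stays free.
            result = await asyncio.to_thread(decode_fn, item)
            await sink.put(result)  # waits while the sink is full (backpressure)

    async def produce() -> None:
        try:
            await asyncio.gather(*(worker() for _ in range(decode_concurrency)))
        finally:
            await sink.put(done)  # always signal the end, even on failure

    producer = asyncio.create_task(produce())
    results: list[object] = []
    while (item := await sink.get()) is not done:
        results.append(item)
    await producer  # re-raises any decode error
    return results
```

For example, `asyncio.run(run_pipeline((f"v{i}" for i in range(10)), str.upper, 3, 2))` returns the ten upper-cased names. Because the sink is bounded, a slow consumer stalls the decode workers instead of letting decoded results pile up in memory.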

Classes

class PerfResult(elapsed: float, num_batches: int, num_frames: int)

Used to report the worker performance to the main process.

elapsed: float

The time it took to process all the inputs.

num_batches: int

The number of batches processed.

num_frames: int

The number of frames processed.
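The main process combines the per-worker PerfResult values by summing frames and batches across workers and dividing by the mean elapsed time, as entrypoint() does in the source listing. The sketch below restates that arithmetic with a local copy of the dataclass; `aggregate` is a hypothetical helper name.

```python
from dataclasses import dataclass


@dataclass
class PerfResult:
    elapsed: float
    num_batches: int
    num_frames: int


def aggregate(vals: list[PerfResult]) -> tuple[float, float]:
    """Return (FPS, BPS) the same way ``entrypoint`` computes them."""
    ave_time = sum(v.elapsed for v in vals) / len(vals)
    total_frames = sum(v.num_frames for v in vals)
    total_batches = sum(v.num_batches for v in vals)
    return total_frames / ave_time, total_batches / ave_time
```

For two workers reporting (10 s, 3 batches, 300 frames) and (20 s, 3 batches, 600 frames), the mean time is 15 s, so the aggregate is 900 / 15 = 60 FPS and 6 / 15 = 0.4 batches per second. Because the workers run in parallel, dividing by the average rather than the summed elapsed time approximates the combined throughput.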