spdl.source.DistributedRandomSampler

class DistributedRandomSampler(n: int, /, *, rank: int, world_size: int, num_draws: int | None = None, weights: list[float] | None = None, seed: int = 0)[source]

Bases: object

Sample dataset indices for the given distributed node while applying randomness.

This sampler ensures that each rank in a distributed training setup gets a disjoint subset of the data indices. When distributed training is not initialized, it returns all indices.

Example - Distributed sampling

The samplers from all ranks together cover the entire dataset.

>>> N = 9
>>> sampler = DistributedRandomSampler(N, rank=0, world_size=3)
>>> list(sampler)
[3, 2, 7]
>>> sampler = DistributedRandomSampler(N, rank=1, world_size=3)
>>> list(sampler)
[6, 1, 4]
>>> sampler = DistributedRandomSampler(N, rank=2, world_size=3)
>>> list(sampler)
[5, 8, 0]
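The disjoint split can be understood as every rank agreeing on one seeded global permutation and taking its own slice of it. Below is a minimal stdlib-only sketch of that idea; it is not spdl's actual implementation, and `rank_indices` is a hypothetical helper introduced only for illustration:

```python
import random

def rank_indices(n: int, rank: int, world_size: int, seed: int = 0) -> list[int]:
    """Hypothetical sketch of disjoint per-rank index assignment."""
    # Every rank shuffles the full index list with the SAME seed,
    # so all ranks agree on a single global permutation.
    rng = random.Random(seed)
    perm = list(range(n))
    rng.shuffle(perm)
    # Each rank then takes a strided slice of that permutation,
    # which makes the per-rank subsets disjoint and jointly exhaustive.
    return perm[rank::world_size]

# The three ranks together cover all 9 indices exactly once.
parts = [rank_indices(9, rank, 3) for rank in range(3)]
```

Because every rank derives its slice from the same permutation, no communication between ranks is needed to keep the subsets disjoint.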

Without calling shuffle(), the sampler produces the same sequence on every iteration. To get a different order each epoch, call sampler.shuffle(seed=epoch) before iterating:

Example

>>> sampler = DistributedRandomSampler(5, rank=0, world_size=1)
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> # Without shuffling, the second iteration generates the same sequence
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> sampler.shuffle(seed=1)
>>> list(sampler)
[3, 2, 4, 1, 0]

You can use embed_shuffle() to shuffle automatically at each iteration.

Example - Auto-shuffle

>>> sampler = embed_shuffle(DistributedRandomSampler(5, rank=0, world_size=1))
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> list(sampler)
[3, 2, 4, 1, 0]
>>> list(sampler)
[2, 1, 4, 3, 0]
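Conceptually, an auto-shuffling wrapper re-seeds and re-shuffles on every pass. The class below is a hypothetical stdlib-only sketch of that behavior, not spdl's `embed_shuffle()` implementation:

```python
import random
from collections.abc import Iterator

class AutoShuffle:
    """Hypothetical sketch: reshuffle automatically on each iteration."""

    def __init__(self, n: int, seed: int = 0) -> None:
        self.n = n
        self.seed = seed

    def __iter__(self) -> Iterator[int]:
        # Shuffle with the current seed, then advance the seed,
        # so the next epoch yields a different order.
        perm = list(range(self.n))
        random.Random(self.seed).shuffle(perm)
        self.seed += 1
        yield from perm

src = AutoShuffle(5)
epoch1 = list(src)  # uses seed 0
epoch2 = list(src)  # uses seed 1
```

Each epoch still visits every index exactly once; only the order changes.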

This is especially useful when the sampler is iterated in a subprocess via iterate_in_subprocess(), where calling shuffle() manually from the main process has no effect on the subprocess copy.

Example - Running in a subprocess

When iterating the sampler in a subprocess, wrap it with embed_shuffle() so that each epoch is automatically reshuffled inside the subprocess:

from functools import partial
from spdl.pipeline import iterate_in_subprocess
from spdl.source import DistributedRandomSampler
from spdl.source.utils import embed_shuffle

sampler = DistributedRandomSampler(N, rank=rank, world_size=world_size)
src = iterate_in_subprocess(embed_shuffle(sampler))

# Each epoch, ranks generate a different disjoint set of indices
for epoch in range(num_epochs):
    for idx in src:
        ...

This sampler also supports weighted sampling.

Example - Weighted sampling

By providing sampling weights, indices are drawn according to the given weights. In this case, indices are sampled with replacement, so they do not necessarily cover the entire dataset.

>>> N = 5
>>> weights = [0, 0, 1, 1, 1]
>>> sampler = DistributedRandomSampler(5, rank=0, world_size=1, weights=weights)
>>> list(sampler)
[2, 4, 3, 3, 2]

With weighted sampling, you can draw more indices than the size of the dataset.

>>> sampler = DistributedRandomSampler(
...     5, rank=0, world_size=1, weights=weights, num_draws=10)
>>> list(sampler)
[2, 4, 3, 3, 2, 4, 3, 2, 4, 2]
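The weighted behavior corresponds to sampling with replacement, as provided by the stdlib's `random.choices`. The helper below is illustrative only and not part of spdl:

```python
import random

def weighted_draw(weights: list[float], num_draws: int, seed: int = 0) -> list[int]:
    """Hypothetical sketch: weighted index sampling with replacement."""
    rng = random.Random(seed)
    # Zero-weight indices are never drawn; because sampling is with
    # replacement, num_draws may exceed the dataset size.
    return rng.choices(range(len(weights)), weights=weights, k=num_draws)

# Indices 0 and 1 have zero weight, so only 2, 3, 4 can appear.
idx = weighted_draw([0, 0, 1, 1, 1], num_draws=10)
```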
Parameters:
  • n – The size of the dataset.

  • rank – The rank in the distributed communication config. You can fetch the values with torch.distributed.get_rank().

  • world_size – The number of ranks in the distributed communication config. You can fetch the values with torch.distributed.get_world_size().

  • num_draws – Optional: The number of samples to draw at each iteration. When performing weighted sampling (i.e. weights is provided), this can be greater than the size of the dataset. Otherwise, it must be smaller than or equal to the size of the dataset.

  • weights – Optional: The sampling weight of each sample in the dataset. When provided, the length of the sequence must match the size of the dataset (n). Indices are drawn with replacement when weights are provided.

  • seed – The seed value for generating the sequence.

Methods

shuffle(seed)

Set the random seed for future iterations.

__iter__() → Iterator[int][source]

Iterate over the indices assigned to the current rank.

Yields:

Individual indices assigned to the current rank.

__len__() → int[source]

The number of indices returned by this sampler.

shuffle(seed: int) → None[source]

Set the random seed for future iterations.

The resulting sequence depends only on the given seed value and is not affected by any prior iterations or previous calls to shuffle(). Calling shuffle(seed=K) always produces the same sequence for a given sampler configuration, regardless of history.
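This contract can be illustrated with the stdlib's random module: rebuilding the RNG from the seed makes the resulting sequence a pure function of that seed. This is a sketch of the contract only, not spdl's internals:

```python
import random

def permutation(n: int, seed: int) -> list[int]:
    # Recreating the RNG from the seed each time means the result
    # depends only on (n, seed), not on any prior call history.
    rng = random.Random(seed)
    perm = list(range(n))
    rng.shuffle(perm)
    return perm

first = permutation(5, seed=1)
_ = permutation(5, seed=99)     # unrelated calls in between...
again = permutation(5, seed=1)  # ...do not change the result
```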