spdl.source.DistributedRandomSampler

class DistributedRandomSampler(n: int, /, *, rank: int, world_size: int, num_draws: int | None = None, weights: list[float] | None = None, seed: int = 0)[source]

Bases: object

Sample dataset indices for the given distributed node while applying randomness.

This sampler ensures that each rank in a distributed training setup gets a disjoint subset of the data indices. When distributed training is not initialized, it returns all indices.

This sampler can apply two kinds of randomness: shuffling and weighted sampling.

Example

>>> sampler = DistributedRandomSampler(5, rank=0, world_size=1)
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> # Without re-shuffling, the second iteration generates the same sequence
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> sampler.shuffle(seed=1)
>>> list(sampler)
[3, 2, 4, 1, 0]

You can use embed_shuffle() to shuffle automatically at each iteration.

>>> sampler = embed_shuffle(DistributedRandomSampler(5, rank=0, world_size=1))
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> list(sampler)
[3, 2, 4, 1, 0]
>>> list(sampler)
[2, 1, 4, 3, 0]

Example - Distributed sampling

The samplers from all ranks together cover the entire dataset.

>>> N = 9
>>> sampler = DistributedRandomSampler(N, rank=0, world_size=3)
>>> list(sampler)
[3, 2, 7]
>>> sampler = DistributedRandomSampler(N, rank=1, world_size=3)
>>> list(sampler)
[6, 1, 4]
>>> sampler = DistributedRandomSampler(N, rank=2, world_size=3)
>>> list(sampler)
[5, 8, 0]

Example - Weighted sampling

When sampling weights are provided, indices are drawn according to those weights. In this case, indices are sampled with replacement, so they do not necessarily cover the entire dataset.

>>> N = 5
>>> weights = [0, 0, 1, 1, 1]
>>> sampler = DistributedRandomSampler(N, rank=0, world_size=1, weights=weights)
>>> list(sampler)
[2, 4, 3, 3, 2]

With weighted sampling, you can draw more indices than the size of the dataset.

>>> sampler = DistributedRandomSampler(
...     N, rank=0, world_size=1, weights=weights, num_draws=10)
>>> list(sampler)
[2, 4, 3, 3, 2, 4, 3, 2, 4, 2]

Parameters:
  • n – The size of the dataset.

  • rank – The rank in the distributed communication config. You can fetch the value with torch.distributed.get_rank().

  • world_size – The number of ranks in the distributed communication config. You can fetch the value with torch.distributed.get_world_size(). (See the sketch after this parameter list.)

  • num_draws – The number of indices to draw at each iteration. When performing weighted sampling (i.e. weights is provided), it can be greater than the size of the dataset. Otherwise, it must be smaller than or equal to the size of the dataset.

  • weights – The sampling weight of each sample in the dataset. When provided, the length of the sequence must match the size of the dataset (n).

  • seed – The seed value for generating the sequence.
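
In a typical torch.distributed setup, rank and world_size can be taken directly from the initialized process group. The following is a minimal sketch, assuming the process group is already initialized and that dataset is some sized container (both are illustrative, not part of this API):

>>> import torch.distributed as dist
>>> sampler = DistributedRandomSampler(
...     len(dataset),
...     rank=dist.get_rank(),
...     world_size=dist.get_world_size(),
... )
>>> indices = list(sampler)  # this rank's disjoint share of the dataset indices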

Methods

shuffle(seed)

Set the random seed for future iterations.

__iter__() → Iterator[int][source]

Iterate over the indices assigned to the current rank.

Yields:

Individual indices assigned to the current rank.

__len__() → int[source]

The number of indices returned by this sampler.

shuffle(seed: int) → None[source]

Set the random seed for future iterations.
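
A common pattern is to re-seed the sampler at the start of each epoch so that every epoch iterates the indices in a different order. The following is a sketch, assuming num_epochs and dataset are defined elsewhere; using the same seed on every rank should keep the per-rank subsets consistent across the group:

>>> for epoch in range(num_epochs):
...     sampler.shuffle(seed=epoch)  # new seed, new order for this epoch
...     for idx in sampler:
...         sample = dataset[idx]
...         ...  # feed the sample to the training step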