spdl.source.DistributedRandomSampler¶
- class DistributedRandomSampler(n: int, /, *, rank: int, world_size: int, num_draws: int | None = None, weights: list[float] | None = None, seed: int = 0)[source]¶
Bases: object

Sample dataset indices for the given distributed node while applying randomness.
This sampler ensures that each rank in a distributed training setup gets a disjoint subset of the data indices. When distributed training is not initialized, it returns all indices.
Example - Distributed sampling
The samplers from all ranks together cover the entire dataset.
>>> N = 9
>>> sampler = DistributedRandomSampler(N, rank=0, world_size=3)
>>> list(sampler)
[3, 2, 7]
>>> sampler = DistributedRandomSampler(N, rank=1, world_size=3)
>>> list(sampler)
[6, 1, 4]
>>> sampler = DistributedRandomSampler(N, rank=2, world_size=3)
>>> list(sampler)
[5, 8, 0]
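The partitioning behavior above can be sketched in plain Python. The following is a hypothetical re-implementation for illustration only (not SPDL's actual internals): shuffle all indices with a shared seed, then hand each rank a disjoint chunk of the permutation.

```python
import random


def distributed_random_sample(n, rank, world_size, seed=0):
    # Hypothetical sketch: shuffle indices with a seed shared by all
    # ranks, then give each rank a disjoint contiguous chunk.
    rng = random.Random(seed)
    perm = list(range(n))
    rng.shuffle(perm)
    per_rank = n // world_size
    start = rank * per_rank
    return perm[start:start + per_rank]


# Together, the ranks cover the dataset with disjoint index sets.
parts = [distributed_random_sample(9, rank=r, world_size=3) for r in range(3)]
flat = [i for part in parts for i in part]
assert sorted(flat) == list(range(9))  # disjoint and exhaustive
```

Because every rank shuffles with the same seed, the ranks agree on the permutation and therefore never overlap, which is the property the example demonstrates.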
Without calling shuffle(), the sampler produces the same sequence on every iteration. To get a different order each epoch, call sampler.shuffle(seed=epoch) before iterating:
Example
>>> sampler = DistributedRandomSampler(5, rank=0, world_size=1)
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> # Without shuffling, the second iteration generates the same sequence
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> sampler.shuffle(seed=1)
>>> list(sampler)
[3, 2, 4, 1, 0]
You can use embed_shuffle() to shuffle automatically at each iteration.
Example - Auto-shuffle
>>> sampler = embed_shuffle(DistributedRandomSampler(5, rank=0, world_size=1))
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> list(sampler)
[3, 2, 4, 1, 0]
>>> list(sampler)
[2, 1, 4, 3, 0]
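The idea behind an auto-shuffling wrapper like embed_shuffle() can be sketched with a toy sampler. The class names and the reseed-per-epoch scheme below are illustrative assumptions, not SPDL's actual implementation:

```python
import random


class ToySampler:
    # Toy stand-in exposing the same shuffle()/__iter__ interface.
    def __init__(self, n, seed=0):
        self.n, self.seed = n, seed

    def shuffle(self, seed):
        self.seed = seed

    def __iter__(self):
        perm = list(range(self.n))
        random.Random(self.seed).shuffle(perm)
        return iter(perm)


class AutoShuffle:
    # Hypothetical wrapper: reseed the sampler before each epoch so
    # every iteration yields a freshly shuffled sequence.
    def __init__(self, sampler):
        self.sampler, self.epoch = sampler, 0

    def __iter__(self):
        self.sampler.shuffle(seed=self.epoch)
        self.epoch += 1
        return iter(self.sampler)


wrapped = AutoShuffle(ToySampler(5))
first, second = list(wrapped), list(wrapped)
assert sorted(first) == sorted(second) == list(range(5))
```

Because the reseeding happens inside __iter__, whoever iterates the wrapper (including a subprocess) triggers the reshuffle, with no action needed from the caller.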
This is especially useful when the sampler is iterated in a subprocess via iterate_in_subprocess(), where calling shuffle() manually from the main process has no effect on the subprocess copy.
Example - Running in a subprocess
When iterating the sampler in a subprocess, wrap it with embed_shuffle() so that each epoch is automatically reshuffled inside the subprocess:

from spdl.pipeline import iterate_in_subprocess
from spdl.source import DistributedRandomSampler
from spdl.source.utils import embed_shuffle

sampler = DistributedRandomSampler(N, rank=rank, world_size=world_size)
src = iterate_in_subprocess(embed_shuffle(sampler))

# Each epoch, the ranks generate different disjoint sets of indices
for epoch in range(num_epochs):
    for idx in src:
        ...
This sampler also supports weighted sampling.
Example - Weighted sampling
By providing sampling weights, indices are drawn according to the given weights. In this case, indices are sampled with replacement, so they do not necessarily cover the entire dataset.
>>> N = 5
>>> weights = [0, 0, 1, 1, 1]
>>> sampler = DistributedRandomSampler(5, rank=0, world_size=1, weights=weights)
>>> list(sampler)
[2, 4, 3, 3, 2]
With weighted sampling, you can draw more indices than the size of the dataset.
>>> sampler = DistributedRandomSampler(
...     5, rank=0, world_size=1, weights=weights, num_draws=10)
>>> list(sampler)
[2, 4, 3, 3, 2, 4, 3, 2, 4, 2]
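Weighted drawing with replacement can be sketched with the standard library. This is an illustrative analogue of the behavior described above, not the sampler's actual RNG, so the drawn values will differ:

```python
import random


def weighted_draws(n, weights, num_draws, seed=0):
    # Draw `num_draws` indices with replacement, proportionally to
    # `weights`; zero-weight indices are never selected.
    rng = random.Random(seed)
    return rng.choices(range(n), weights=weights, k=num_draws)


idx = weighted_draws(5, [0, 0, 1, 1, 1], num_draws=10)
assert len(idx) == 10
assert all(i in (2, 3, 4) for i in idx)  # zero-weight indices never drawn
```

Sampling with replacement is what allows num_draws to exceed the dataset size, at the cost of possibly repeating or skipping individual indices.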
- Parameters:
n – The size of the dataset.
rank – The rank in the distributed communication config. You can fetch the value with torch.distributed.get_rank().
world_size – The number of ranks in the distributed communication config. You can fetch the value with torch.distributed.get_world_size().
num_draws – Optional. The number of samples to draw at each iteration. When performing weighted sampling (weights is provided), this can be greater than the size of the dataset. Otherwise, it must be smaller than or equal to the size of the dataset.
weights – Optional. The sampling weight of each sample in the dataset. When provided, the length of the sequence must match the size of the dataset (n). Indices are drawn with replacement when weights are provided.
seed – The seed value for generating the sequence.
Methods
shuffle(seed) – Set the random seed for future iterations.
- __iter__() Iterator[int][source]¶
Iterate over the indices assigned to the current rank.
- Yields:
Individual indices assigned to the current rank.
- shuffle(seed: int) None[source]¶
Set the random seed for future iterations.
The resulting sequence depends only on the given seed value and is not affected by any prior iterations or previous calls to shuffle(). Calling shuffle(seed=K) always produces the same sequence for a given sampler configuration, regardless of history.
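This seed-only determinism can be demonstrated with a toy sampler that follows the same contract (an illustrative stand-in, not the actual class):

```python
import random


class SeededSampler:
    # Toy sampler honoring the documented contract: the output
    # sequence depends only on the stored seed, never on history.
    def __init__(self, n, seed=0):
        self.n, self.seed = n, seed

    def shuffle(self, seed):
        self.seed = seed

    def __iter__(self):
        perm = list(range(self.n))
        random.Random(self.seed).shuffle(perm)
        return iter(perm)


s = SeededSampler(5)
s.shuffle(seed=7)
first = list(s)
list(s); list(s)  # extra iterations do not perturb internal state
s.shuffle(seed=7)
assert list(s) == first  # same seed -> same sequence, regardless of history
```

Constructing a fresh generator from the stored seed on each iteration is what makes the sequence a pure function of the seed, which is the property the paragraph above describes.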