spdl.source.DistributedRandomSampler
- class DistributedRandomSampler(n: int, /, *, rank: int, world_size: int, num_draws: int | None = None, weights: list[float] | None = None, seed: int = 0)[source]
Bases: object
Sample dataset indices for the given distributed node while applying randomness.
This sampler ensures that each rank in a distributed training setup gets a disjoint subset of the data indices. When distributed training is not initialized, it returns all indices.
This sampler can apply two kinds of randomness: shuffling and weighted sampling.
Example
>>> sampler = DistributedRandomSampler(5, rank=0, world_size=1)
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> # Without re-shuffling, the second iteration generates the same sequence
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> sampler.shuffle(seed=1)
>>> list(sampler)
[3, 2, 4, 1, 5]
You can use embed_shuffle() to shuffle automatically at each iteration.
>>> sampler = embed_shuffle(DistributedRandomSampler(5, rank=0, world_size=1))
>>> list(sampler)
[4, 2, 0, 1, 3]
>>> list(sampler)
[3, 2, 4, 1, 5]
>>> list(sampler)
[2, 1, 4, 3, 5]
Example - Distributed sampling
The samplers from all ranks together cover the entire dataset.
>>> N = 9
>>> sampler = DistributedRandomSampler(N, rank=0, world_size=3)
>>> list(sampler)
[3, 2, 7]
>>> sampler = DistributedRandomSampler(N, rank=1, world_size=3)
>>> list(sampler)
[6, 1, 4]
>>> sampler = DistributedRandomSampler(N, rank=2, world_size=3)
>>> list(sampler)
[5, 8, 0]
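As a quick check of this property (a sketch that assumes the example outputs above), the union of the indices produced by all ranks recovers every dataset index exactly once:

>>> all_indices = set()
>>> for r in range(3):
...     all_indices.update(DistributedRandomSampler(N, rank=r, world_size=3))
>>> sorted(all_indices) == list(range(N))
True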
Example - Weighted sampling
When sampling weights are provided, indices are drawn according to those weights. In this case, indices are sampled with replacement, so they do not necessarily cover the entire dataset.
>>> N = 5
>>> weights = [0, 0, 1, 1, 1]
>>> sampler = DistributedRandomSampler(5, rank=0, world_size=1, weights=weights)
>>> list(sampler)
[2, 4, 3, 3, 2]
With weighted sampling, you can draw more indices than the size of the dataset.
>>> sampler = DistributedRandomSampler(
...     5, rank=0, world_size=1, weights=weights, num_draws=10)
>>> list(sampler)
[2, 4, 3, 3, 2, 4, 3, 2, 4, 2]
- Parameters:
n – The size of the dataset.
rank – The rank in the distributed communication config. You can fetch the value with torch.distributed.get_rank() (see the sketch after this list).
world_size – The number of ranks in the distributed communication config. You can fetch the value with torch.distributed.get_world_size().
num_draws – The number of samples to draw at each iteration. When performing weighted sampling (deterministic=False and weights is provided), it can be greater than the size of the dataset. Otherwise, it must be smaller than or equal to the size of the dataset.
weights – Optional. The sampling weight of each sample in the dataset. When provided, the length of the sequence must match the size of the dataset (n). This option is ignored if deterministic=True.
seed – The seed value for generating the sequence.
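Below is a minimal sketch of constructing the sampler from an already-initialized torch.distributed process group; the dataset size and the assumption that dist.init_process_group() was called elsewhere are for illustration only, not part of this API.

>>> import torch.distributed as dist
>>> from spdl.source import DistributedRandomSampler
>>> dataset_size = 10_000  # hypothetical dataset size
>>> sampler = DistributedRandomSampler(
...     dataset_size,
...     rank=dist.get_rank(),              # this process's rank
...     world_size=dist.get_world_size(),  # total number of processes
... )
>>> indices = list(sampler)  # disjoint per-rank subset of range(dataset_size)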
Methods
shuffle(seed)
Set the random seed for future iterations.
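A common pattern, sketched here with illustrative names (dataset, num_epochs, rank, and world_size are assumptions, not part of this API), is to reseed the sampler before each epoch so that every pass visits the indices in a different order:

>>> sampler = DistributedRandomSampler(len(dataset), rank=rank, world_size=world_size)
>>> for epoch in range(num_epochs):
...     sampler.shuffle(seed=epoch)  # reseed so this epoch yields a new order
...     for idx in sampler:
...         ...  # load and process dataset[idx]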