spdl.source.DistributedDeterministicSampler¶
- class DistributedDeterministicSampler(n: int, /, *, rank: int, world_size: int)[source]¶
Bases:
object
Sampler for distributed training that splits indices across multiple ranks.
This sampler ensures that each rank in a distributed training setup gets a disjoint subset of the data indices. When distributed training is not initialized, it returns all indices.
- Parameters:
n – The size of the dataset.
rank – The rank in the distributed communication config. You can fetch the value with
torch.distributed.get_rank()
.
world_size – The number of ranks in the distributed communication config. You can fetch the value with
torch.distributed.get_world_size()
.
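The core idea of disjoint per-rank splitting can be sketched in plain Python. The strided (round-robin) assignment below is an illustrative assumption, not necessarily the exact strategy `DistributedDeterministicSampler` uses internally; the helper name `split_indices` is hypothetical.

```python
def split_indices(n: int, rank: int, world_size: int) -> list[int]:
    """Return the subset of range(n) assigned to `rank`.

    Illustrative sketch: the strided split shown here is an assumption,
    not a description of DistributedDeterministicSampler's internals.
    """
    if world_size <= 1:
        # No distributed setup: this process sees all indices.
        return list(range(n))
    # Strided assignment: rank r gets indices r, r + world_size, ...
    # Every index belongs to exactly one rank, so the subsets are disjoint.
    return list(range(rank, n, world_size))


# With n=10 and world_size=4, the four ranks together cover range(10)
# exactly once, with no overlap between ranks.
parts = [split_indices(10, r, 4) for r in range(4)]
```

Because the assignment depends only on `n`, `rank`, and `world_size`, every rank deterministically computes its own subset without any communication.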