spdl.source.DistributedDeterministicSampler
- class DistributedDeterministicSampler(n: int, /, *, rank: int, world_size: int, ddp_drop_last_distributed_round: bool = True)
Bases: object
Sampler for distributed training that splits indices across multiple ranks.
This sampler ensures that each rank in a distributed training setup gets a disjoint subset of the data indices. When distributed training is not initialized, it returns all indices.
The iteration order is deterministic and always the same: indices are assigned in a round-robin fashion (range(rank, N, world_size)). Every iteration produces the identical sequence. If you need a different order each epoch, use DistributedRandomSampler instead.
When the dataset size is not divisible by world_size, the final round is incomplete. The ddp_drop_last_distributed_round argument controls how this leftover is handled: when True (default), the incomplete final round is dropped so every rank receives the same number of indices; when False, every sample is covered, so some ranks receive one more index than others.

    Dataset indices, N = 11, world_size = 4

    ┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐
    │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 8 │ 9 │10 │
    └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘

    ddp_drop_last_distributed_round=True (default)
      rank 0: 0, 4
      rank 1: 1, 5
      rank 2: 2, 6
      rank 3: 3, 7
      (8, 9, 10 are dropped)

    ddp_drop_last_distributed_round=False
      rank 0: 0, 4, 8
      rank 1: 1, 5, 9
      rank 2: 2, 6, 10
      rank 3: 3, 7
      (all indices are covered)
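The round-robin assignment described above can be sketched in plain Python. This is a minimal illustration of the documented behavior, not the library's implementation: round_robin_indices is a hypothetical helper, and its drop_last argument stands in for ddp_drop_last_distributed_round.

```python
def round_robin_indices(
    n: int, rank: int, world_size: int, drop_last: bool = True
) -> list[int]:
    """Hypothetical helper mirroring the documented assignment; not part of spdl."""
    if drop_last:
        # Keep only the complete rounds of `world_size` indices.
        n = (n // world_size) * world_size
    # Each rank takes every `world_size`-th index, starting at its own rank.
    return list(range(rank, n, world_size))


# Reproduce the N = 11, world_size = 4 case for both settings.
for drop_last in (True, False):
    print(f"drop_last={drop_last}")
    for rank in range(4):
        print(f"  rank {rank}: {round_robin_indices(11, rank, 4, drop_last)}")
```

With drop_last=True every rank ends up with two indices; with drop_last=False ranks 0 to 2 receive three indices and rank 3 receives two, so the per-rank lengths differ by at most one.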
- Parameters:
n – The size of the dataset.
rank – The rank of the current process in the distributed communication config. You can fetch the value with torch.distributed.get_rank().
world_size – The number of ranks in the distributed communication config. You can fetch the value with torch.distributed.get_world_size().
ddp_drop_last_distributed_round – If True (default), drop the final incomplete distributed round so every rank receives the same number of indices. If False, cover every sample exactly once across ranks; rank lengths may differ by at most one.
Added in version 0.5.0: The ddp_drop_last_distributed_round argument.
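One possible way to obtain rank and world_size before constructing the sampler is sketched below. The helper get_rank_and_world_size is hypothetical, and the fallback to (0, 1) when PyTorch is missing or the process group is not initialized is an assumption of this sketch, chosen so that a single process covers all indices.

```python
def get_rank_and_world_size() -> tuple[int, int]:
    """Hypothetical helper: return (rank, world_size), or (0, 1) as a fallback."""
    try:
        import torch.distributed as dist
    except ImportError:
        # Assumption of this sketch: without PyTorch, act as a single process.
        return 0, 1
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    # Process group not initialized: treat the run as single-process.
    return 0, 1


rank, world_size = get_rank_and_world_size()
print(rank, world_size)

# With spdl installed, the values would be passed as keyword-only arguments,
# following the signature documented above:
#   from spdl.source import DistributedDeterministicSampler
#   sampler = DistributedDeterministicSampler(11, rank=rank, world_size=world_size)
```

Both rank and world_size are keyword-only in the documented signature, so passing them by name is required.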