neuralbench.utils.make_weighted_sampler

neuralbench.utils.make_weighted_sampler(dataset: SegmentDataset, logger: Logger, generator: Generator | None = None) → WeightedRandomSampler

Create a weighted random sampler for the given dataset to handle class imbalance.
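The idea behind weighted sampling for class imbalance can be sketched in plain Python. This is a minimal illustration, not neuralbench's implementation: it assumes inverse-class-frequency weights (the exact formula used by compute_class_weights_from_dataset() is not documented here), the helper inverse_frequency_weights is hypothetical, and random.Random.choices (sampling with replacement) stands in for torch's WeightedRandomSampler.

```python
import random
from collections import Counter


def inverse_frequency_weights(targets):
    """Hypothetical helper: weight each sample by 1 / (size of its class)."""
    counts = Counter(targets)
    return [1.0 / counts[t] for t in targets]


# Hypothetical imbalanced labels: four samples of class 0, one of class 1.
targets = [0, 0, 0, 0, 1]
weights = inverse_frequency_weights(targets)

# Every class now carries the same total weight (4 * 0.25 == 1 * 1.0).
class_totals = {c: sum(w for w, t in zip(weights, targets) if t == c) for c in set(targets)}

# Draw indices with replacement, proportionally to the per-sample weights.
rng = random.Random(0)
indices = rng.choices(range(len(targets)), weights=weights, k=10_000)
minority_fraction = sum(targets[i] == 1 for i in indices) / len(indices)
print(weights, class_totals, round(minority_fraction, 2))
```

Because each class receives equal total weight, the minority class is drawn roughly half the time instead of one time in five.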

Parameters:
  • dataset – Training dataset whose targets drive the class-weight computation.

  • logger – Logger forwarded to compute_class_weights_from_dataset().

  • generator – Optional torch.Generator used by the returned sampler. When set, every iteration of the sampler draws from this generator instead of the global torch RNG, so the sequence of sampled indices depends only on the generator's state (e.g. its seed) and is unaffected by other code that consumes the global RNG.
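The reproducibility property described for generator can be illustrated with a stdlib analogue. In this sketch a seeded random.Random plays the role of a torch.Generator seeded with manual_seed(); the helpers draw_indices and run_epochs are hypothetical and only mimic weighted sampling with replacement, not the actual sampler.

```python
import random

# Hypothetical per-sample weights (e.g. inverse class frequencies).
weights = [0.25, 0.25, 0.25, 0.25, 1.0]


def draw_indices(weights, num_samples, generator):
    """Weighted sampling with replacement, driven only by `generator`."""
    return generator.choices(range(len(weights)), weights=weights, k=num_samples)


def run_epochs(seed, epochs=2, num_samples=8):
    # One private generator shared across epochs, as a sampler would reuse it.
    gen = random.Random(seed)
    return [draw_indices(weights, num_samples, gen) for _ in range(epochs)]


run_a = run_epochs(1234)

# Perturb the global RNG; the private generator must not be affected.
random.seed(999)
random.random()

run_b = run_epochs(1234)
print(run_a == run_b)
```

Re-creating the generator with the same seed reproduces the whole multi-epoch index sequence, while consecutive epochs drawn from the same generator continue its stream rather than repeating it.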