neuralbench.utils.make_weighted_sampler
- neuralbench.utils.make_weighted_sampler(dataset: SegmentDataset, logger: Logger, generator: Generator | None = None) → WeightedRandomSampler
Create a weighted random sampler for the given dataset to handle class imbalance.
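A common way for a weighted sampler to counter class imbalance is inverse-frequency weighting: each sample gets weight 1 / (size of its class), so every class carries the same total weight and is drawn with roughly equal probability. The sketch below assumes that scheme purely for illustration. The helper `inverse_frequency_sample_weights` is hypothetical, and the exact weighting performed by `compute_class_weights_from_dataset()` is not documented here.

```python
from collections import Counter
from typing import Sequence


def inverse_frequency_sample_weights(targets: Sequence[int]) -> list[float]:
    """Hypothetical helper: weight each sample by 1 / (count of its class).

    Every class then sums to a total weight of 1.0, so a sampler drawing
    with replacement picks each class with equal probability.
    """
    counts = Counter(targets)
    return [1.0 / counts[t] for t in targets]


# Class 0 is three times as frequent as class 1.
targets = [0, 0, 0, 1]
weights = inverse_frequency_sample_weights(targets)
print(weights)
```

Weights of this form are what `torch.utils.data.WeightedRandomSampler` expects: one non-negative value per sample, not per class.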
- Parameters:
dataset – Training dataset whose targets drive the class-weight computation.
logger – Logger forwarded to compute_class_weights_from_dataset().
generator – Optional torch.Generator used by the returned sampler. When set, successive iterations of the sampler draw from this generator instead of the global torch RNG, so the sequence of sampled indices depends only on the generator's seed.
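The reproducibility property of the `generator` parameter can be checked with plain PyTorch. This sketch builds a `WeightedRandomSampler` directly rather than calling `make_weighted_sampler`, because constructing a `SegmentDataset` and a `Logger` is outside the scope of this page. The per-sample weights and the `draw` helper are illustrative assumptions.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Illustrative per-sample weights (assumption: inverse class frequency for
# targets [0, 0, 0, 1]); make_weighted_sampler derives its own from the dataset.
weights = [1 / 3, 1 / 3, 1 / 3, 1.0]


def draw(seed: int, num_samples: int = 8) -> list[int]:
    # A freshly seeded generator makes the sampled index sequence depend
    # only on `seed`, not on the global torch RNG state.
    gen = torch.Generator().manual_seed(seed)
    sampler = WeightedRandomSampler(
        weights, num_samples=num_samples, replacement=True, generator=gen
    )
    return list(sampler)


torch.manual_seed(0)
first = draw(42)
torch.manual_seed(123)  # perturbing the global RNG has no effect on the result
second = draw(42)
assert first == second
print(first)
```

Because the sampler reuses the same generator object across epochs, each new iteration advances the generator's state and yields a different, but still reproducible, sequence. The returned sampler is meant to be passed as `DataLoader(dataset, sampler=sampler, ...)`; `DataLoader` rejects `shuffle=True` when a sampler is given.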