Source code for neuralbench.data

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import typing as tp

import numpy as np
import torch
from pydantic import field_validator
from torch.utils.data import DataLoader
from tqdm import tqdm

import neuralset as ns

from .extractors import SleepOnsetTargetExtractor  # noqa: F401
from .transforms import (  # noqa: F401
    AddDefaultEvents,
    AddSleepOnsetTargets,
    CropSleepRecordings,
    CropTimelines,
    OffsetEvents,
    PredefinedSplit,
    ShuffleTrainingLabels,
    SimilaritySplit,
    SklearnSplit,
    TextPreprocessor,
)
from .utils import make_regression_bin_sampler, make_weighted_sampler, seed_worker

LOGGER = logging.getLogger(__name__)


class BaseSampler(ns.base.NamedModel):
    """Abstract configuration object describing a training-time sampler.

    A concrete subclass overrides :meth:`build` to decide how examples are
    drawn from a :class:`~neuralset.dataloader.SegmentDataset`, usually in
    order to compensate for class or target imbalance.  The subclass to use
    is picked in YAML through the ``name`` discriminator, for instance
    ``sampler: {name: ClassificationSampler}``.

    The configured sampler is applied to the training DataLoader only; the
    val and test DataLoaders always walk through their dataset in order.
    """

    def build(
        self,
        dataset: ns.dataloader.SegmentDataset,
        generator: torch.Generator | None = None,
    ) -> torch.utils.data.Sampler:
        """Create the sampler used to draw from the training ``dataset``.

        Parameters
        ----------
        dataset
            The training segment dataset.  Subclasses typically materialise
            its targets in order to derive per-sample weights.
        generator
            Optional :class:`torch.Generator` handed to the returned
            sampler.  When provided, the sequence of draws is determined by
            this generator's seed alone rather than by the global ``torch``
            RNG; :class:`~neuralbench.data.Data` relies on this to make
            training runs reproducible.

        Raises
        ------
        NotImplementedError
            Always, on the base class: subclasses must override this method.
        """
        raise NotImplementedError


class ClassificationSampler(BaseSampler):
    """Class-balancing weighted sampler for imbalanced classification tasks.

    Balanced class weights are derived from the training targets and used as
    per-sample weights (the inverse of each sample's class frequency) of a
    :class:`torch.utils.data.WeightedRandomSampler`.  As a result every
    class carries, in expectation, the same total mass in each training
    epoch.

    Intended for multi-class classification with skewed label distributions
    (e.g. seizure detection, sleep arousal).  Multilabel targets are not
    supported yet -- see the TODO in :func:`make_weighted_sampler`.
    """

    def build(
        self,
        dataset: ns.dataloader.SegmentDataset,
        generator: torch.Generator | None = None,
    ) -> torch.utils.data.Sampler:
        """Return the weighted sampler for ``dataset``.

        The work is delegated to :func:`make_weighted_sampler`, which
        receives the module logger and the optional ``generator``.
        """
        weighted_sampler = make_weighted_sampler(
            dataset,
            generator=generator,
            logger=LOGGER,
        )
        return weighted_sampler


class RegressionBinSampler(BaseSampler):
    """Inverse-frequency weighted sampler stratified by binned regression targets.

    The regression counterpart of :class:`ClassificationSampler`.  Targets
    are bucketised by ``bin_edges`` and assigned inverse-frequency sampling
    weights so that, in expectation, every populated bin contributes the
    same total mass to each training epoch.

    Bin semantics match :class:`~neuralbench.metrics.BinnedMAE`:

    * Bin ``i`` covers ``[bin_edges[i], bin_edges[i + 1])`` for
      ``i < n_bins - 1``; the top bin is closed on the right so targets
      exactly at ``bin_edges[-1]`` are counted in the top bin.
    * Targets outside ``[bin_edges[0], bin_edges[-1]]`` receive **zero
      weight** -- they are excluded from sampling and do not contribute to
      any bin's count.

    Parameters
    ----------
    bin_edges
        Strictly increasing sequence of length ``>= 2`` defining the bins,
        e.g. ``[0.0, 40.0, 90.0, 300.0, 600.0]`` for the sleep-onset task.
        NaN edges are rejected at validation time.
    """

    bin_edges: list[float]

    @field_validator("bin_edges")
    @classmethod
    def _validate_bin_edges(cls, value: list[float]) -> list[float]:
        """Check that ``bin_edges`` describes at least one well-formed bin.

        Raises
        ------
        ValueError
            If fewer than two edges are given, or if the edges are not
            strictly increasing (which includes any NaN edge).
        """
        if len(value) < 2:
            raise ValueError(
                f"bin_edges must have length >= 2 to define at least one bin, "
                f"got {value}."
            )
        # ``not lo < hi`` rather than ``hi <= lo``: every comparison involving
        # NaN is False, so the latter form would let NaN edges slip through.
        # With at least two edges, each edge takes part in at least one pair.
        if any(not lo < hi for lo, hi in zip(value, value[1:])):
            raise ValueError(f"bin_edges must be strictly increasing, got {value}.")
        return value

    def build(
        self,
        dataset: ns.dataloader.SegmentDataset,
        generator: torch.Generator | None = None,
    ) -> torch.utils.data.Sampler:
        """Return the bin-balanced sampler for ``dataset``.

        Delegates to :func:`make_regression_bin_sampler` with the validated
        ``bin_edges``, the module logger and the optional ``generator``.
        """
        return make_regression_bin_sampler(
            dataset, bin_edges=self.bin_edges, logger=LOGGER, generator=generator
        )


class Data(ns.BaseModel):
    """Create dataloaders for brain-modeling experiments.

    Parameters
    ----------
    study
        Step whose ``run()`` returns the events table to segment.
    neuro
        Extractor registered under the ``"neuro"`` key of each segment.
    target
        Extractor registered under the ``"target"`` key of each segment.
    channel_positions
        Channel-position configuration.  It is only built (and registered
        under ``"channel_positions"``) when ``neuro`` is an
        :class:`ns.extractors.MneRaw`.
    trigger_event_type
        Event type, or list of event types, selecting the trigger events
        through the query ``type in [...]``.
    start, duration, stride, stride_drop_incomplete
        Forwarded unchanged to :class:`ns.dataloader.Segmenter`.
    sampler
        Optional sampler configuration, applied to the train split only.
    batch_size, num_workers, pin_memory
        Forwarded unchanged to every :class:`~torch.utils.data.DataLoader`.
    drop_last
        Only honoured for the train split; val and test never drop the last
        batch.
    persistent_workers, prefetch_factor
        Only take effect when ``num_workers > 0``.
    seed
        When ``None``, shuffling and weighted sampling follow the global
        ``torch`` RNG.  Otherwise independent sub-seeds are derived from it
        for the train, val and test loaders and for the train sampler.
    summary_columns
        Extra event columns whose number of unique values is reported in
        the logged dataset summary.
    """

    study: ns.Step
    neuro: ns.extractors.BaseExtractor
    target: ns.extractors.BaseExtractor
    channel_positions: ns.extractors.ChannelPositions

    # Segments
    trigger_event_type: str | list[str]
    start: float = -0.5
    duration: float | None = 3
    stride: float | None = None
    stride_drop_incomplete: bool = True

    # Dataloaders
    sampler: BaseSampler | None = None
    batch_size: int = 64
    num_workers: int = 0
    drop_last: bool = False
    pin_memory: bool = True
    persistent_workers: bool = True
    prefetch_factor: int | None = None
    seed: int | None = None

    # Others
    summary_columns: list[str] = []

    # Label encoder over the ``subject`` event field; created in
    # ``model_post_init`` because it needs ``neuro.event_types``.
    _subject_id: ns.extractors.LabelEncoder | None = None

    def model_post_init(self, __context):
        """Create the subject-id label encoder once the fields are set."""
        super().model_post_init(__context)
        self._subject_id = ns.extractors.LabelEncoder(
            event_types=self.neuro.event_types,
            event_field="subject",
            return_one_hot=False,
        )

    def prepare(self) -> dict[str, DataLoader]:
        """Load events, build extractors, segment data and return train/val/test DataLoaders.

        Returns
        -------
        dict with keys ``"train"``, ``"val"``, ``"test"`` mapping to
        :class:`~torch.utils.data.DataLoader` instances.
        """
        events = self.study.run()
        if "split" not in events.columns:
            # NOTE(review): this only logs; execution continues and the
            # summary below groups by "split", so a failure there is likely.
            LOGGER.error(
                "No `split` column found in events. Make sure splits are defined in the study, "
                "or use an events transform (`neuralset.events.transforms`) to add them."
            )

        # Log, per (study, split, type), the number of unique values of each
        # summary column.
        summary_columns = ["index", "subject"] + self.summary_columns + ["timeline"]
        summary_df = (
            events.reset_index()
            .groupby(["study", "split", "type"], dropna=False)[summary_columns]
            .nunique()
        )
        LOGGER.info("Dataset summary:\n%s", summary_df.to_string())

        extractors = {
            "neuro": self.neuro,
            "target": self.target,
            "subject_id": self._subject_id,
        }
        if isinstance(self.neuro, ns.extractors.MneRaw):
            # Prepare the neuro extractor first because the channel positions depend on it
            self.neuro.prepare(events)
            channels = self.neuro._channels
            assert channels is not None
            channel_positions = self.channel_positions.build(self.neuro)
            LOGGER.info(
                "Found %s different channels: %s",
                len(channels),
                list(channels.keys()),
            )
            extractors["channel_positions"] = channel_positions

        # Normalise the trigger type(s) to a list so the query reads
        # ``type in [...]`` in both cases.
        trigger_event_type = (
            [self.trigger_event_type]
            if isinstance(self.trigger_event_type, str)
            else self.trigger_event_type
        )
        segmenter = ns.dataloader.Segmenter(
            start=self.start,
            duration=self.duration,
            trigger_query=f"type in {trigger_event_type}",
            stride=self.stride,
            stride_drop_incomplete=self.stride_drop_incomplete,
            extractors=extractors,  # type: ignore[arg-type]
        )
        dataset = segmenter.apply(events)
        dataset.prepare()

        # Derive four independent RNG streams from ``self.seed`` so that each
        # consumer (train DataLoader shuffle + train worker base-seeds, train
        # WeightedRandomSampler multinomial draws, val worker base-seeds,
        # test worker base-seeds) is a pure function of its own sub-seed.
        # Per-split DataLoader generators matter when ``num_workers > 0``:
        # ``DataLoader.__iter__`` consumes one int64 from ``generator`` to
        # derive each worker's base seed, so sharing one generator across
        # splits would couple the train shuffle stream to how often Lightning
        # iterates val/test (sanity check, per-epoch validation, etc.).
        loader_gens: dict[str, torch.Generator | None]
        sampler_gen: torch.Generator | None
        worker_init_fn: tp.Callable[[int], None] | None
        if self.seed is None:
            LOGGER.info(
                "Data.seed=None; dataloader shuffling and weighted sampling "
                "will follow the caller's global torch RNG state."
            )
            loader_gens = {"train": None, "val": None, "test": None}
            sampler_gen = None
            worker_init_fn = None
        else:
            # The unpacking order (train, sampler, val, test) determines which
            # sub-seed each consumer gets; changing it changes seeded runs.
            train_state, sampler_state, val_state, test_state = np.random.SeedSequence(
                self.seed
            ).generate_state(4)
            loader_gens = {
                "train": torch.Generator().manual_seed(int(train_state)),
                "val": torch.Generator().manual_seed(int(val_state)),
                "test": torch.Generator().manual_seed(int(test_state)),
            }
            sampler_gen = torch.Generator().manual_seed(int(sampler_state))
            worker_init_fn = seed_worker

        # Worker-related options only apply when worker processes exist; they
        # do not depend on the split, so compute them once.
        uses_workers = self.num_workers > 0
        persistent_workers = self.persistent_workers and uses_workers
        prefetch_factor = self.prefetch_factor if uses_workers else None

        # Create the dataloaders
        loaders = {}
        for split in tqdm(["train", "val", "test"], desc="Preparing segments"):
            split_dataset = dataset.select(dataset.triggers.split == split)
            LOGGER.info("# %s segments: %s \n", split, len(split_dataset))

            is_train = split == "train"
            sampler = None
            if is_train and self.sampler is not None:
                sampler = self.sampler.build(split_dataset, generator=sampler_gen)

            loaders[split] = DataLoader(
                split_dataset,
                collate_fn=split_dataset.collate_fn,
                batch_size=self.batch_size,
                # Shuffle only the train split, and only when no sampler is set.
                shuffle=is_train and sampler is None,
                sampler=sampler,
                num_workers=self.num_workers,
                drop_last=self.drop_last and is_train,
                pin_memory=self.pin_memory,
                persistent_workers=persistent_workers,
                prefetch_factor=prefetch_factor,
                generator=loader_gens[split],
                worker_init_fn=worker_init_fn,
            )

        return loaders