# Source code for neuralbench.data

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging

from torch.utils.data import DataLoader
from tqdm import tqdm

import neuralset as ns

from .transforms import (  # noqa: F401
    AddDefaultEvents,
    CropSleepRecordings,
    CropTimelines,
    OffsetEvents,
    PredefinedSplit,
    ShuffleTrainingLabels,
    SimilaritySplit,
    SklearnSplit,
    TextPreprocessor,
)
from .utils import make_weighted_sampler

# Module-level logger, following the standard `logging.getLogger(__name__)` pattern;
# used by Data.prepare() for dataset summaries and split statistics.
LOGGER = logging.getLogger(__name__)


class Data(ns.BaseModel):
    """Build train/val/test dataloaders for brain-modeling experiments.

    Wires together a study (event source), neuro/target extractors and a
    segmenter, then wraps the resulting splits in
    :class:`torch.utils.data.DataLoader` instances.
    """

    # Pipeline components (project-declared types from `neuralset`).
    study: ns.Step
    neuro: ns.extractors.BaseExtractor
    target: ns.extractors.BaseExtractor
    channel_positions: ns.extractors.ChannelPositions

    # Segments
    trigger_event_type: str | list[str]
    start: float = -0.5
    duration: float | None = 3
    stride: float | None = None
    stride_drop_incomplete: bool = True

    # Dataloaders
    use_weighted_sampler: bool = False
    batch_size: int = 64
    num_workers: int = 0
    drop_last: bool = False
    pin_memory: bool = True
    persistent_workers: bool = True
    prefetch_factor: int | None = None

    # Others
    summary_columns: list[str] = []
    # Built in model_post_init; encodes the per-event "subject" field as an
    # integer id (not one-hot).
    _subject_id: ns.extractors.LabelEncoder | None = None

    def model_post_init(self, __context):
        """Pydantic v2 hook: construct the subject-id label encoder after init."""
        super().model_post_init(__context)
        self._subject_id = ns.extractors.LabelEncoder(
            event_types=self.neuro.event_types,
            event_field="subject",
            return_one_hot=False,
        )

    def prepare(self) -> dict[str, DataLoader]:
        """Load events, build extractors, segment data and return train/val/test DataLoaders.

        Returns
        -------
        dict with keys ``"train"``, ``"val"``, ``"test"`` mapping to
        :class:`~torch.utils.data.DataLoader` instances.
        """
        events = self.study.run()
        # Splits must already be present on the events; warn loudly otherwise.
        if "split" not in events.columns:
            LOGGER.error(
                "No `split` column found in events. Make sure splits are defined in the study, "
                "or use an events transform (`neuralset.events.transforms`) to add them."
            )

        # Per-(study, split, type) cardinality summary, logged for inspection.
        cols = ["index", "subject"] + self.summary_columns + ["timeline"]
        summary_df = (
            events.reset_index()
            .groupby(["study", "split", "type"], dropna=False)[cols]
            .nunique()
        )
        LOGGER.info("Dataset summary:\n%s", summary_df.to_string())

        extractors = {
            "neuro": self.neuro,
            "target": self.target,
            "subject_id": self._subject_id,
        }
        if isinstance(self.neuro, ns.extractors.MneRaw):
            # Prepare the neuro extractor first because the channel positions depend on it
            self.neuro.prepare(events)
            channels = self.neuro._channels
            assert channels is not None
            channel_positions = self.channel_positions.build(self.neuro)
            LOGGER.info(
                f"Found {len(channels)} different channels: {list(channels.keys())}"
            )
            extractors["channel_positions"] = channel_positions

        # Normalize to a list so the trigger query is uniform.
        trigger_types = (
            [self.trigger_event_type]
            if isinstance(self.trigger_event_type, str)
            else self.trigger_event_type
        )
        segmenter = ns.dataloader.Segmenter(
            start=self.start,
            duration=self.duration,
            trigger_query=f"type in {trigger_types}",
            stride=self.stride,
            stride_drop_incomplete=self.stride_drop_incomplete,
            extractors=extractors,  # type: ignore[arg-type]
        )
        dataset = segmenter.apply(events)
        dataset.prepare()

        # Create the dataloaders
        loaders = {}
        for split in tqdm(["train", "val", "test"], desc="Preparing segments"):
            split_dataset = dataset.select(dataset.triggers.split == split)
            LOGGER.info(f"# {split} segments: {len(split_dataset)} \n")

            # Optional weighted sampling on the training split only; a sampler
            # and shuffle are mutually exclusive in DataLoader.
            sampler = None
            if split == "train" and self.use_weighted_sampler:
                sampler = make_weighted_sampler(split_dataset, logger=LOGGER)

            # Worker-dependent options are only valid when workers exist.
            keep_workers = self.persistent_workers and self.num_workers > 0
            loaders[split] = DataLoader(
                split_dataset,
                collate_fn=split_dataset.collate_fn,
                batch_size=self.batch_size,
                shuffle=split == "train" and sampler is None,
                sampler=sampler,
                num_workers=self.num_workers,
                drop_last=self.drop_last and split == "train",
                pin_memory=self.pin_memory,
                persistent_workers=keep_workers,
                prefetch_factor=self.prefetch_factor
                if self.num_workers > 0
                else None,
            )
        return loaders