Source code for neuralset.dataloader

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=super-init-not-called

import collections
import concurrent.futures
import dataclasses
import logging
import typing as tp
import warnings

import numpy as np
import pandas as pd
import pydantic
import torch
from pydantic import PositiveFloat

import neuralset as ns

from . import base
from .events import EventTypesHelper
from .extractors import BaseExtractor
from .extractors import BaseExtractor as Feat
from .segments import find_incomplete_segments

logger = logging.getLogger(__name__)

T = tp.TypeVar("T", bound="SegmentsMixin")


@dataclasses.dataclass(kw_only=True)
class SegmentsMixin:
    """Mixin class for segment-structured classes.

    Subclasses must implement ``_subselect`` so that ``.select()`` works.
    """

    segments: tp.Sequence[ns.segments.Segment]

    def _duration_str(self) -> str:
        durs = {s.duration for s in self.segments}
        if len(durs) == 1:
            return f"{next(iter(durs)):.2f}s"
        return f"{min(durs):.2f}-{max(durs):.2f}s"

    def __post_init__(self):
        if not isinstance(self.segments, list):
            raise TypeError("segments must be a list of Segment instances")

    def _subselect(self: T, idx: list[int]) -> T:
        """To be defined by subclass: how to subselect from a list of indices"""
        raise NotImplementedError

    def select(
        self: T,
        idx: int | list[int] | list[bool] | np.ndarray | pd.Series,
    ) -> T:
        """Subselect segments by integer indices or boolean mask.

        Parameters
        ----------
        idx
            One of:
            - ``int``: single segment index
            - ``list[int]`` or ``np.ndarray[int]``: positional indices
            - ``list[bool]``, ``np.ndarray[bool]``, or ``pd.Series[bool]``:
              boolean mask (must have length equal to the number of segments)
        """
        if isinstance(idx, pd.Series):
            idx = idx.to_numpy()  # type: ignore[assignment]
        indices = np.atleast_1d(idx)
        if indices.ndim != 1:
            raise ValueError(f"select indices must be 1-d, got shape {indices.shape}")
        if indices.dtype == bool:
            if len(indices) != len(self):
                raise ValueError(
                    f"Boolean mask length {len(indices)} != "
                    f"number of segments {len(self)}"
                )
            indices = np.where(indices)[0]
        elif indices.dtype.kind not in ("i", "u"):
            raise ValueError(f"select indices must be int or bool, got {indices.dtype}")
        if not len(indices):
            raise ValueError("Empty subselection")
        return self._subselect(indices.tolist())

    @property
    def triggers(self) -> pd.DataFrame:
        triggers = [s.trigger for s in self.segments]
        if any(t is None for t in triggers):
            raise RuntimeError(f"Segments did not all have a trigger: {triggers}")
        df = pd.DataFrame([t.to_dict() for t in triggers])  # type: ignore
        df.index = df.pop("Index")  # type: ignore
        return df

    @property
    def events(self) -> pd.DataFrame:
        """All of the events spanned by all the segments"""
        # TODO FIXME why do segment events have both an index AND an "Index" column
        # TODO discuss: should the trigger event be in events? Probably not
        return pd.concat([seg.events for seg in self.segments]).drop_duplicates()

    def __len__(self) -> int:
        return len(self.segments)


@dataclasses.dataclass(kw_only=True)
class Batch(SegmentsMixin):
    """A collection of extracted features for a list of segments.

    Returned by the :class:`SegmentDataset` or ``DataLoader``, this object holds
    both the underlying segments and the computed tensor data ready for machine
    learning models.

    Attributes
    ----------
    data : dict of str to torch.Tensor
        The extracted feature tensors. Keys match the names of the extractors
        passed to the Segmenter. Tensors are always batched along the first
        dimension.
    segments : list of Segment
        The :class:`~neuralset.segments.Segment` instances corresponding to this
        batch. Useful for debugging or custom logic.
    """

    data: dict[str, torch.Tensor]

    def __post_init__(self) -> None:
        super().__post_init__()
        if not isinstance(self.data, dict):
            raise TypeError(f"'data' needs to be a dict, got: {type(self.data)}")
        if not self.data:
            raise ValueError(f"No data in {self}")
        # check batch dimension
        batch_size = next(iter(self.data.values())).shape[0]
        if len(self.segments) != batch_size:
            raise RuntimeError(
                f"Incoherent batch size {batch_size} for "
                f"{len(self.segments)} segments in {self}"
            )

    def to(self, device: str) -> "Batch":
        """Creates a new instance on the appropriate device"""
        out = {name: d.to(device) for name, d in self.data.items()}
        return Batch(data=out, segments=self.segments)

    # pylint: disable=unused-argument
    def __getitem__(self, key: str) -> None:
        raise RuntimeError("Batch is not a dict, use batch.data instead")

    def __repr__(self) -> str:
        shapes = "; ".join(
            f"{k}: {', '.join(str(d) for d in v.shape)}" for k, v in self.data.items()
        )
        n = len(self.segments)
        return f"Batch({n} segments, {self._duration_str()}, shapes: {shapes})"

    def _subselect(self, idx: list[int]):
        """subselect the dataset through index or query"""
        segments = [self.segments[i] for i in idx]
        data = {key: d[idx] for key, d in self.data.items()}
        return self.__class__(data=data, segments=segments)
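

# Illustrative sketch (added for exposition, not part of the original module):
# subselecting a Batch and moving it to another device. The helper name is
# hypothetical; ``batch`` is assumed to come from a DataLoader built with
# ``SegmentDataset.collate_fn``.
def _batch_usage_sketch(batch: Batch, device: str = "cpu") -> Batch:
    # keep only the first segment; tensors stay batched along dim 0
    first = batch.select(0)
    assert len(first) == 1
    assert all(tensor.shape[0] == 1 for tensor in first.data.values())
    # tensor access goes through ``batch.data`` (a Batch is not itself a dict)
    return first.to(device)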


def validate_extractors(extractors: tp.Mapping[str, Feat]) -> tp.Mapping[str, Feat]:
    """Validate the extractor container provided as input and map all cases to
    the more general form, a dict of extractors.
    """
    if not extractors:
        return {}
    # use extractor names for list
    if not isinstance(extractors, collections.abc.Mapping):
        raise ValueError(f"Only dict of extractors are supported, got {type(extractors)}")
    # single extractors are mapped to list to unify all cases
    return extractors


def prepare_extractors(
    extractors: tp.Sequence[BaseExtractor] | tp.Mapping[str, BaseExtractor],
    events: tp.Any,
) -> None:
    """Prepare the slurm-based extractors in parallel, and the others sequentially.

    Parameters
    ----------
    extractors: list/dict of extractors
        the extractors to prepare
    events: DataFrame or list of events/segments
        The structure containing all the events, be it a dataframe, a list of
        events, or a list of segments.
    """
    from .events.utils import extract_events

    events = extract_events(events)
    extractor_list = list(
        extractors.values() if isinstance(extractors, dict) else extractors
    )
    extractors_using_slurm = [
        extractor
        for extractor in extractor_list
        if hasattr(extractor, "infra")
        and getattr(extractor.infra, "cluster", None) == "slurm"
    ]
    other_extractors = [
        extractor
        for extractor in extractor_list
        if extractor not in extractors_using_slurm
    ]
    slurm_names = ", ".join(
        extractor.__class__.__name__ for extractor in extractors_using_slurm
    )
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for extractor in extractors_using_slurm:
            futures.append(executor.submit(extractor.prepare, events))
            # for debugging:
            futures[-1].__dict__["_name"] = extractor.__class__.__name__
        if extractors_using_slurm:
            msg = f"Started parallel preparation of extractors {slurm_names} on slurm"
            logger.info(msg)
        for extractor in other_extractors:
            logger.info(f"Preparing extractor: {extractor.__class__.__name__}")
            extractor.prepare(events)
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()  # Raise any exception from the task
            except Exception as e:
                name = future.__dict__.get("_name", "UNKNOWN")
                logger.warning("Error occurred while preparing extractor %s: %s", name, e)
                raise


def _get_pad_lengths(
    extractors: tp.Mapping[str, Feat],
    pad_duration: float | None,  # in seconds
) -> dict[str, int]:
    """Precompute pad length in samples for each extractor if applicable

    extractors: mapping of Extractors
        the extractors
    pad_duration: float or None
        padding duration in seconds (if any)
    """
    pad_lengths: dict[str, int] = {}
    if pad_duration is None:
        return pad_lengths
    for name, f in extractors.items():
        if isinstance(f, Feat):
            freq = base.Frequency(f.frequency)
            pad_lengths[name] = freq.to_ind(pad_duration)
    return pad_lengths


def _pad_to(tensor: torch.Tensor, pad_len: int | None):
    """Pad the last dimension to a given length"""
    if pad_len is None:
        return tensor
    if pad_len < tensor.shape[-1]:
        msg = "Pad duration is shorter than segment duration, cropping."
        warnings.warn(msg, UserWarning)
        return tensor[:, :pad_len]
    else:
        return torch.nn.functional.pad(tensor, (0, pad_len - tensor.shape[-1]))
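

# Illustrative sketch (added for exposition, not part of the original module) of
# the private padding helper above: ``_pad_to`` pads or crops the last dimension
# to a target length. The helper name and the tensor shapes are hypothetical.
def _padding_sketch() -> None:
    short = torch.zeros(3, 5)  # e.g. (channels, time)
    assert _pad_to(short, pad_len=8).shape == (3, 8)  # zero-padded on the right
    longer = torch.zeros(3, 10)
    assert _pad_to(longer, pad_len=8).shape == (3, 8)  # cropped, with a UserWarning
    assert _pad_to(short, pad_len=None).shape == (3, 5)  # no padding requested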


@dataclasses.dataclass
class SegmentDataset(torch.utils.data.Dataset[Batch], SegmentsMixin):
    """Dataset defined through :class:`~neuralset.segments.Segment` instances
    and :class:`~neuralset.extractors.BaseExtractor` instances.

    Parameters
    ----------
    extractors: dict of :class:`~neuralset.extractors.BaseExtractor`
        extractors to be computed, returned in the Batch.data dictionary items
    segments: list of :class:`~neuralset.segments.Segment`
        the list of segment instances defining the dataset
    pad_duration: float | tp.Literal["auto"] | None
        pad the segments to the maximum duration or to a specific duration
        None: no padding. Will throw error if segment durations vary.
        "auto": will pad with the max(segments.duration)
    remove_incomplete_segments: bool
        remove segments which do not contain events for one of the extractors
    transforms: dict, optional
        Map of extractor names to transforms (callables transforming the
        extractor tensor). If an extractor name is not present, no transform
        is applied. Keys must be a subset of the extractor names.

    Usage
    -----

    .. code-block:: python

        extractors = {"whatever": ns.extractors.Pulse()}
        ds = ns.SegmentDataset(extractors, segments)

        # one data item
        item = ds[0]
        assert item.data["whatever"].shape[0] == 1  # batch dimension is always added

        # through dataloader:
        dataloader = torch.utils.data.DataLoader(ds, collate_fn=ds.collate_fn, batch_size=2)
        batch = next(iter(dataloader))
        print(batch.data["whatever"])
        # batch.segments holds the corresponding segments
    """

    def __init__(
        self,
        extractors: tp.Mapping[str, Feat],
        segments: tp.Sequence[ns.segments.Segment],
        *,
        remove_incomplete_segments: bool = False,
        pad_duration: float | tp.Literal["auto"] | None = None,
        transforms: dict[str, tp.Callable] | None = None,
    ) -> None:
        super().__init__(segments=segments)
        self.extractors = validate_extractors(extractors)
        self.pad_duration = pad_duration
        self.remove_incomplete_segments = remove_incomplete_segments
        self.segments = _remove_incomplete_segments(
            list(segments), extractors, remove_incomplete_segments
        )
        self._pad_lengths: None | dict[str, int] = None
        transforms = transforms or {}
        additional = set(transforms) - set(extractors)
        if additional:
            raise ValueError(f"Keys in transforms are not present in data: {additional}")
        self.transforms = transforms
        # Hold strong refs to _EventStore instances so they survive pickling
        # to spawn workers (store.__setstate__ self-registers in the worker).
        estores = {seg._store_id: seg._store_ref for seg in self.segments}
        self._event_stores = list(estores.values())

    def __getstate__(self) -> dict[str, tp.Any]:
        state = self.__dict__.copy()
        # Lightweight segments for spawn-worker pickle: copy() produces
        # cache-free clones so we don't serialize per-segment event lists.
        state["segments"] = [seg.copy() for seg in self.segments]
        return state

    def prepare(self) -> None:
        prepare_extractors(self.extractors, self.segments)

    def collate_fn(self, batches: list[Batch]) -> Batch:
        """Creates a new instance from several instances by concatenating all
        data tensors along the (first) batch dimension
        """
        if not batches:
            return Batch(data={}, segments=[])
        if len(batches) == 1:
            return batches[0]
        if not batches[0].data:
            raise ValueError(f"No extractor in first batch: {batches[0]}")
        # move everything to pytorch if first one is numpy
        extractors = {}
        for name in batches[0].data:
            data = [b.data[name] for b in batches]
            try:
                extractors[name] = torch.cat(data, axis=0)  # type: ignore
            except Exception:
                string = f"Failed to collate data with shapes {[d.shape for d in data]}\n"
                logger.warning(string)
                raise
        segments = [s for b in batches for s in b.segments]
        return Batch(data=extractors, segments=segments)

    def _check_padding(self) -> None:
        # check if padding is needed
        if self._pad_lengths is not None:
            return
        if self.pad_duration is None:
            if len({s.duration for s in self.segments}) > 1:
                msg = "Segments have different durations, so they cannot be collated into batches."
                msg += " Set `pad_duration` to `auto` to pad the segments to the maximum duration."
                raise ValueError(msg)
            pad_duration = self.pad_duration
        elif self.pad_duration == "auto":
            pad_duration = max(s.duration for s in self.segments)
        else:
            pad_duration = self.pad_duration
        self._pad_lengths = _get_pad_lengths(self.extractors, pad_duration)

    def __len__(self) -> int:
        return len(self.segments)

    def __getitem__(self, idx: int | slice) -> Batch:
        if not isinstance(idx, (int, slice)):
            raise ValueError(f"idx must be int or slice, got {type(idx)}")
        self._check_padding()
        if isinstance(idx, slice):
            indices = list(range(len(self))[idx])
            if not indices:
                return self.collate_fn([])
            return self._subselect(indices).load_all()
        assert isinstance(self._pad_lengths, dict)  # for mypy
        seg = self.segments[idx]
        events = seg.ns_events
        out: dict[str, torch.Tensor] = {}
        for name, extractor in self.extractors.items():
            data = extractor(
                events,
                start=seg.start,
                duration=seg.duration,
                trigger=seg.trigger,
            )
            # pad if need be
            data = _pad_to(data, self._pad_lengths.get(name, None))
            # add back the batch dimension and store under the extractor name
            out[name] = data[None, ...]
        segment_data = Batch(data=out, segments=[seg])
        for key in segment_data.data:
            if key in self.transforms:
                segment_data.data[key] = self.transforms[key](segment_data.data[key])
        return segment_data

    def build_dataloader(self, **kwargs: tp.Any) -> torch.utils.data.DataLoader:
        """Returns a dataloader for this dataset"""
        self._check_padding()
        return torch.utils.data.DataLoader(self, collate_fn=self.collate_fn, **kwargs)

    def load_all(self, num_workers: int = 0) -> Batch:
        """Returns a single batch with all the dataset data, un-shuffled"""
        num_workers = min(num_workers, len(self))
        batch_size = len(self)
        if num_workers > 1:
            batch_size = max(1, len(self) // (3 * num_workers))
        if num_workers == 1:
            num_workers = 0  # simplifies debugging
        loader = self.build_dataloader(
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=False,
        )
        return self.collate_fn(list(loader))

    def as_one_batch(self, num_workers: int = 0) -> Batch:
        """Deprecated: use :meth:`load_all` instead."""
        warnings.warn(
            "as_one_batch has been renamed to load_all",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.load_all(num_workers=num_workers)

    def __repr__(self) -> str:
        ext = ", ".join(self.extractors) or "none"
        n = len(self.segments)
        return f"SegmentDataset({n} segments, {self._duration_str()}, extractors: {ext})"

    def _subselect(self, idx: list[int]):
        """Subselect the dataset through trigger index or query"""
        segments = [self.segments[i] for i in idx]
        return self.__class__(
            extractors=self.extractors,
            segments=segments,
            pad_duration=self.pad_duration,
            # TODO this should be done in the dataset builder
            remove_incomplete_segments=self.remove_incomplete_segments,
            transforms=self.transforms,
        )
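

# Illustrative sketch (added for exposition, not part of the original module):
# pulling data out of a SegmentDataset. The helper name is hypothetical; it assumes
# the segments share a duration, or that ``pad_duration`` was set, so that the
# items can be collated into a single Batch.
def _dataset_loading_sketch(dataset: SegmentDataset) -> Batch:
    # slicing loads and collates the selected segments into one Batch
    head = dataset[:4]
    assert isinstance(head, Batch)
    # load the whole dataset, un-shuffled, as a single Batch
    full = dataset.load_all(num_workers=0)
    assert len(full) == len(dataset)
    return full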


class Segmenter(base.BaseModel):
    """Build a :class:`SegmentDataset` from an events DataFrame and extractors.

    Parameters
    ----------
    extractors: dict of :class:`~neuralset.extractors.BaseExtractor`
        extractors to be computed, returned in the Batch.data dictionary items
    start: float
        Start time (in seconds) of the segment, with respect to the :term:`trigger`
        event (or stride). E.g. use -1.0 if you want the segment to start 1s before
        the event.
    duration: optional float
        Duration (in seconds) of the segment (defaults to the event duration if
        ``trigger_query`` is used to extract segments based on specific events).
    trigger_query: optional Query
        Dataframe query selecting which events act as :term:`triggers <trigger>`;
        segments are time-locked to the matching events (see :data:`base.Query`).
        At least one of ``trigger_query`` or ``stride`` must be provided.
    stride: optional float
        Stride (in seconds) to use to define sliding window segments.
    stride_drop_incomplete: optional bool
        If True and stride is not None, drop segments that are not fully contained
        within the (start, stop) block.
    padding: optional float | tp.Literal["auto"] | None
        pad the segments to the maximum duration or to a specific duration.
        None: no padding. Will throw error if segment durations vary.
        "auto": will pad with the max(segments.duration)
    drop_incomplete: bool
        remove segments which do not contain events for one of the extractors
    drop_unused_events: bool
        remove events not used by the extractors before creating the segments.

    Usage
    -----

    .. code-block:: python

        extractors = {"whatever": ns.extractors.Pulse()}
        segmenter = ns.Segmenter(extractors=extractors)
        dset = segmenter.apply(events)

        # one data item
        item = dset[0]
        assert item.data["whatever"].shape[0] == 1  # batch dimension is always added

        # through dataloader:
        dataloader = torch.utils.data.DataLoader(dset, collate_fn=dset.collate_fn, batch_size=2)
        batch = next(iter(dataloader))
        print(batch.data["whatever"])
        # batch.segments holds the corresponding segments
    """

    # segments
    start: float = 0.0
    duration: PositiveFloat | None
    trigger_query: base.Query | None = None
    stride: PositiveFloat | None = None
    stride_drop_incomplete: bool = True
    # extractors
    extractors: dict[str, BaseExtractor]
    # dataset
    padding: float | None = None
    drop_incomplete: bool = False
    drop_unused_events: bool = True
    # model_config = pydantic.ConfigDict(extra="forbid")

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        if self.trigger_query is None and self.stride is None:
            raise ValueError("At least one of trigger_query or stride must be provided.")

    def apply(self, events: pd.DataFrame) -> SegmentDataset:
        # Segment the events based on stride and/or triggers
        if self.trigger_query is not None:
            trigger_idx = pd.Series(events.query(self.trigger_query).index)
            if not len(trigger_idx):
                raise RuntimeError(
                    f"the trigger query: {self.trigger_query} led to an empty set."
                )
        else:
            trigger_idx = pd.Series(dtype=int)
        # TODO fixme class typing: self.duration cannot be "trigger"
        duration = None if self.duration == "trigger" else self.duration
        # Drop events not used by extractors or triggers
        if self.drop_unused_events:
            event_types = []
            for extractor in self.extractors.values():
                event_types.extend(EventTypesHelper(extractor.event_types).names)
            event_types.extend(events.loc[trigger_idx].type.unique())
            events = events.loc[events.type.isin(event_types)]
        segments = ns.segments.list_segments(
            events=events,
            triggers=trigger_idx,
            start=self.start,
            duration=duration,
            stride=self.stride,
            stride_drop_incomplete=self.stride_drop_incomplete,
        )
        segments = _remove_incomplete_segments(
            segments, self.extractors, self.drop_incomplete
        )
        if not segments:
            raise RuntimeError(f"empty segments with {self!r}")
        ds = SegmentDataset(
            extractors=self.extractors,
            segments=segments,
            pad_duration=self.padding,
            remove_incomplete_segments=False,
        )
        return ds
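

# Illustrative sketch (added for exposition, not part of the original module):
# a sliding-window Segmenter configuration. The extractor key "pulse" and the
# 2s/0.5s window parameters are hypothetical; ``events`` is the usual events
# DataFrame expected by ``Segmenter.apply``.
def _sliding_window_sketch(events: pd.DataFrame) -> SegmentDataset:
    segmenter = Segmenter(
        extractors={"pulse": ns.extractors.Pulse()},
        duration=2.0,  # each segment lasts 2 seconds
        stride=0.5,  # a new segment starts every 0.5 seconds
        drop_incomplete=True,  # drop windows missing events for an extractor
    )
    return segmenter.apply(events)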


def _remove_incomplete_segments(
    segments: list[ns.segments.Segment],
    extractors: tp.Mapping[str, Feat],
    drop_incomplete: bool,
) -> list[ns.segments.Segment]:
    """Check that each segment has at least the events that correspond to each extractor"""
    # List event types required across extractors
    event_types = collections.defaultdict(list)
    for extractor in extractors.values():
        if extractor.aggregation == "trigger":
            continue
        for event_type in extractor._event_types_helper.classes:
            event_types[event_type].append(extractor)
    # Identify problematic segments
    invalid_indices: set[int] = set()
    for event_type, extracts in event_types.items():
        # Find segments that don't have this event type
        invalids = find_incomplete_segments(segments, [event_type])
        if invalids:
            # Check whether the extractors allow missing events
            required = any(not f.allow_missing for f in extracts)
            msg = f"{len(invalids)} segments are missing events of type {event_type.__name__}."
            if not drop_incomplete and required:
                # Raise an error if an extractor requires this event type
                msg += " Use `drop_incomplete=True` to remove them,"
                msg += f" or set {extracts} with `allow_missing=True`."
                raise ValueError(msg)
            if not drop_incomplete and len(invalids) == len(segments):
                msg += f" {extracts} allow missing events, however there are"
                msg += f" no segments with {event_type.__name__}. Missing values cannot be guessed."
            else:
                msg += (
                    " They will be populated with default missing values through prepare."
                )
            logger.info(msg)
            invalid_indices.update(invalids)
    # Drop invalid segments
    if drop_incomplete and invalid_indices:
        msg = f"Removing {len(invalid_indices)} segments out of {len(segments)}"
        logger.info(msg)
        segments = [s for i, s in enumerate(segments) if i not in invalid_indices]
    return segments