Source code for neuralset.extractors.base

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import typing as tp
import warnings

import numpy as np
import pandas as pd
import pydantic
import torch
from exca.base import DEFAULT_CHECK_SKIPS

import neuralset as ns
from neuralset import base
from neuralset.base import Frequency as Frequency
from neuralset.base import TimedArray as TimedArray
from neuralset.events import Event, EventTypesHelper, etypes
from neuralset.segments import Segment

logger = logging.getLogger(__name__)


[docs] class BaseExtractor(base._Module, base.NamedModel): """Base class for extracting features from :term:`events <event>` within a :class:`~neuralset.segments.Segment`. Subclasses define **what** data to extract (e.g. audio embeddings, EEG signals) while ``BaseExtractor`` handles event selection, temporal alignment, and multi-event aggregation. To create a custom extractor, subclass ``BaseExtractor`` and implement: * ``_get_data(events)`` — expensive per-event computation (typically cached via ``exca.MapInfra``). * ``_get_timed_arrays(events, start, duration)`` — return an iterable of :class:`~neuralset.base.TimedArray`, one per event. For extractors that produce a single static value per event (no time dimension), subclass :class:`BaseStatic` instead and override :meth:`~BaseStatic.get_static`. Parameters ---------- event_types : str or tuple of str Event type name(s) this extractor operates on (e.g. ``"Audio"`` or ``("Image", "Text")``). Must be set as a class-level default in every concrete subclass. aggregation : str Strategy for combining values when multiple matching events fall inside the same segment: * ``"single"`` — at most one event per output sample (raises on collision). * ``"sum"`` / ``"mean"`` — element-wise sum or mean. * ``"first"`` / ``"middle"`` / ``"last"`` — pick one event. * ``"cat"`` — concatenate along the first dimension. * ``"stack"`` — stack along a new first dimension. * ``"trigger"`` — use only the trigger event passed to ``__call__``. allow_missing : bool If True, return a zero tensor when no matching event is found in the segment instead of raising. Requires that :meth:`prepare` has been called first so the output shape is known. frequency : float or ``"native"`` Output sampling rate in Hz. Use ``"native"`` to keep the original sampling rate of the input data. ``0`` is reserved for static extractors (:class:`BaseStatic`). """ event_types: str | tuple[str, ...] = "" # eg: event_types: str | tuple[str, ...] = ("Image", "Text") aggregation: tp.Literal[ "single", "sum", "mean", "first", "middle", "last", "cat", "stack", "trigger", ] = "single" # builds output even when no corresponding event is provided allow_missing: bool = False frequency: float | tp.Literal["native"] = 0.0 # internal _CLASSES: tp.ClassVar[dict[str, tp.Type["BaseExtractor"]]] = {} _effective_frequency: float | None = pydantic.PrivateAttr(None) _event_types_helper: EventTypesHelper = pydantic.PrivateAttr() _missing_default: torch.Tensor | None = pydantic.PrivateAttr(None) @classmethod def __pydantic_init_subclass__(cls, **kwargs: tp.Any) -> None: super().__pydantic_init_subclass__(**kwargs) # check params super().__init_subclass__() # add event requirements to the extractor requirements if not cls._can_be_instantiated(): return name = cls.__name__ BaseExtractor._CLASSES[name] = cls model_fields: dict[str, pydantic.FieldInfo] = cls.model_fields # type: ignore event_types: tp.Any = model_fields["event_types"].default # type:ignore if not event_types: msg = f"Default event_types must be specified for {cls.__name__}" raise RuntimeError(msg) # security checks for new _get_data legafuncs = [ "_get_latents", "_get_latent", "_get_preprocessed_data", "_events_to_data", "_get_channel_positions", ] for func in legafuncs: if hasattr(cls, func): msg = f'In {name!r}, found function {func!r} which should be renamed to "_get_data"' raise RuntimeError(msg) infrafield = cls.model_fields.get("infra", None) if infrafield is not None: funcname = infrafield.default._infra_method.method.__name__ if funcname != "_get_data": msg = f'In {name!r}, found infra decorating {funcname!r} it should be "_get_data" by convention' raise RuntimeError(msg) # security checks for new event_types if not isinstance(event_types, str): is_tuple = isinstance(event_types, tuple) if not (is_tuple and all(isinstance(d, str) for d in event_types)): msg = f"In {name!r}, event_types attribute must be a string " msg += f"or tuple of string, got {event_types}" raise TypeError(msg) type_helper = EventTypesHelper(event_types) for etype in type_helper.classes: cls.requirements = cls.requirements + etype.requirements def model_post_init(self, log__: tp.Any) -> None: super().model_post_init(log__) self._event_types_helper = EventTypesHelper(self.event_types) name = self.__class__.__name__ if self.frequency != "native" and self.frequency < 0.0: msg = f"{name}.frequency is neither 'native' nor >= 0 (got {self.frequency})." raise ValueError(msg) if not (self.frequency or isinstance(self, BaseStatic)): msg = f"{name}.frequency=0 is only allowed for static extractors (did you mean 'native'?)" raise ValueError(msg) def _exclude_from_cache_uid(self) -> list[str]: # extractor convention from inheriting cache uid exclusion list return ["aggregation", "allow_missing"]
[docs] def prepare( self, obj: pd.DataFrame | tp.Sequence[Event] | tp.Sequence[Segment] ) -> None: """Pre-compute and cache extractor data for a collection of events. This method triggers ``_get_data`` on every matching event so that expensive computation (e.g. model inference) is done once and cached. It then calls the extractor on a single event to populate the output shape, which is needed when ``allow_missing=True``. Call ``prepare`` before using the extractor in a dataloader. Parameters ---------- obj : DataFrame or sequence of Event or sequence of Segment The structure containing the events. When calling ``prepare`` on several objects, prefer passing a list of events or segments over a DataFrame to avoid redundant conversion overhead. """ events = self._event_types_helper.extract(obj) if self.frequency == "native" and events and hasattr(events[0], "frequency"): freqs = set(e.frequency for e in events) # type: ignore cls = self.__class__.__name__ if len(freqs) > 1: msg = f"frequency='native' in {cls} with several different frequencies: {freqs}" msg += "\n(all data will not be processing at the same frequency, " msg += "should you set the extractor frequency?" logger.warning(msg) elif len(freqs) == 1: cls = self.__class__.__name__ freq = list(freqs)[0] msg = f"Processing to native frequency in {cls}.prepare: {freq}Hz" logger.info(msg) self._get_data(events) if events: # run extractor on 1 event to populate shape self( events[0], start=events[0].start, duration=0.001, trigger=events[0], )
def _get_data(self, events: list[Event]) -> tp.Iterable[tp.Any]: """Put heavy computation steps here, and cache the result using exca.MapInfra""" for _ in events: yield None raise NotImplementedError def _get_timed_arrays( self, events: list[Event], start: float, duration: float ) -> tp.Iterable[TimedArray]: raise NotImplementedError def __call__( self, events: tp.Any, # too complex: pd.DataFrame | list | dict | ns.events.Event, start: float, duration: float, trigger: Event | pd.Series | dict | None = None, ) -> torch.Tensor: """events: the single event (dict | ns.events.Event) or the series of events (list of Events | pd.DataFrame) describing the events, each containing start and duration. start: the start of the segment in the same timeline as the event. duration: the duration of the segment. """ from neuralset.events.utils import extract_events # Check input start, duration = float(start), float(duration) if duration < 0.0: raise ValueError(f"duration must be >= 0, got {duration}") if trigger is not None and not isinstance(trigger, Event): triggers = extract_events(trigger) if len(triggers) != 1: msg = f"trigger must be a single event, got {len(triggers)} from {trigger!r}" raise TypeError(msg) trigger = triggers[0] ns_events = self._get_relevant_events( events, trigger, start=start, duration=duration ) # If there is no valid event, return default values or raise error. if not ns_events: if self.allow_missing and self._missing_default is not None: # Return tensor with default values if you can. if self._effective_frequency is not None: default = self._missing_default ef = Frequency(self._effective_frequency) if ef: n_times = max(1, ef.to_ind(duration)) reps = [1 for _ in range(default.ndim)] + [n_times] default = default.unsqueeze(-1).repeat(reps) return default else: msg = f"_missing_default was set for {self.name} but _effective_frequency is missing" raise RuntimeError(msg) # Raise if you can't return default values. event_types = self._event_types_helper.classes found_types = {type(e) for e in extract_events(events)} msg = f"No {event_types} found in segment for extractor {self.name} " msg += f"(types found: {found_types} in {events!r}) " if not self.allow_missing: msg += ( f"(filter invalid segments or set allow_missing=True to {self.name})" ) else: msg += "and extractor shape not populated " msg += '(you may need to call "prepare" on the extractor).' raise ValueError(msg) # Main part: get data from each event tarrays = list( self._get_timed_arrays(events=ns_events, start=start, duration=duration) ) # Store frequency (TimedArray frequency for native, else declared frequency) freq = float( tarrays[0].frequency if self.frequency == "native" else self.frequency ) if self._effective_frequency is None: self._effective_frequency = freq # Align and combine the data of all events in the segment tensor = self._tarrays_to_tensor( tarrays, start, duration, freq, self.aggregation, ) # Store default value if self._missing_default is None: # last dimension is time if frequency is not 0 shape = tuple(tensor.shape[: -1 if freq else None]) self._missing_default = torch.zeros(*shape, dtype=tensor.dtype) return tensor def _get_relevant_events( self, events: tp.Any, trigger: Event | None, *, start: float, duration: float, ) -> list[Event]: """Select the extractor-relevant events for the segment window and aggregation.""" # if trigger-aggregation, we only extract data from the trigger event if self.aggregation == "trigger": event_types = self._event_types_helper.classes if not isinstance(trigger, event_types): aggregation = self.aggregation msg = f"Extractor {self.name} has {aggregation=} but trigger is {trigger!r} (not {event_types})" raise ValueError(msg) events = [trigger] # make sure events is list[Event] ns_events = self._event_types_helper.extract(events) if ns_events and len(timelines := {e.timeline for e in ns_events}) > 1: msg = f"Multiple timelines {timelines} in events passed to extractor {self!r}: {ns_events}" raise ValueError(msg) if self.aggregation != "trigger": stop = start + duration in_window = [ e for e in ns_events if e.start < stop and e.start + e.duration > start ] # fall back to full list when nothing overlaps, so raw-data events # (Meg, Audio, Fmri) called outside their nominal range still reach # the zero-fill in _tarrays_to_tensor (see test_first_samp). ns_events = in_window or ns_events if self.aggregation in ("first", "trigger"): ns_events = ns_events[:1] elif self.aggregation == "last": ns_events = ns_events[-1:] elif self.aggregation == "middle": ns_events = [ns_events[len(ns_events) // 2]] return ns_events @staticmethod def _tarrays_to_tensor( tarrays: list[TimedArray], start: float, duration: float, frequency: float, aggregation: str, ) -> torch.Tensor: """Combine a list of time array into a torch Tensor.""" # picker modes: keep the chosen tarray and accumulate as sum if aggregation in ("trigger", "first", "middle", "last"): if len(tarrays) != 1: msg = f"Expected 1 tarray for {aggregation=!r}, got {tarrays}" raise RuntimeError(msg) if aggregation == "trigger" and not frequency: tarrays[0].start = start # static trigger: fake an overlap aggregation = "sum" segment_info: dict[str, tp.Any] = { "start": start, "frequency": frequency, "duration": duration, } if aggregation in ("sum", "mean", "single"): tarray = TimedArray(aggregation=aggregation, **segment_info) # type: ignore for ta in tarrays: tarray += ta data = tarray.data elif aggregation in ("cat", "stack"): arrays = [] for ta in tarrays: buf = TimedArray(**segment_info) buf += ta arrays.append(buf.data) combine = np.concatenate if aggregation == "cat" else np.stack data = combine(arrays, axis=0) else: raise RuntimeError(f"unknown aggregation: {aggregation}") tensor = torch.from_numpy(data) if not tensor.ndim: tensor = tensor.unsqueeze(0) return tensor
[docs] class BaseStatic(BaseExtractor): """Base class for extractors that produce one feature vector per event. Subclasses implement :meth:`get_static` instead of the full dynamic pipeline. ``frequency`` defaults to 0 (no time axis). """ frequency: float = 0.0 # FIXME should be pydantic.PositiveFloat
[docs] def get_static(self, event: etypes.Event) -> torch.Tensor: """Return a single feature vector for the given event. Override this method in subclasses to produce a static (non-temporal) embedding for one event. The returned tensor should have no time dimension — temporal wrapping is handled by ``BaseStatic`` automatically. Parameters ---------- event : Event The event to extract a feature from. Returns ------- torch.Tensor A tensor of shape ``(*feature_shape,)`` (no time axis). """ raise NotImplementedError
def _get_timed_arrays( self, events: list[Event], start: float, duration: float ) -> tp.Iterable[TimedArray]: for event in events: embedding = self.get_static(event) ta = TimedArray( frequency=0, duration=event.duration, start=event.start, data=embedding.numpy(), ) yield ta
# pylint: disable=unused-argument def _skip_new_event_types(key, val, prev): if "event_types" in key and not prev: return True return False DEFAULT_CHECK_SKIPS.append(_skip_new_event_types)
[docs] class Pulse(BaseStatic): """Constant-one extractor — returns a single 1.0 scalar for every event. Useful as a bias term or baseline feature in models that expect at least one extractor per event type. """ event_types: str | tuple[str, ...] = "Event"
[docs] def get_static(self, event: etypes.Event) -> torch.Tensor: return torch.ones(1, dtype=torch.float32)
[docs] class EventField(BaseStatic): """Extractor which extracts an int or float attribute from an event. `event_field` can be either an attribute of the event or a key in the event.extra dictionary. Parameters ---------- event_types : str or tuple of str Type of event(s) to apply this extractor to. event_field : str Field to extract from the event. """ event_types: str | tuple[str, ...] = "Event" event_field: str _dtype: torch.dtype | None = pydantic.PrivateAttr(None) def model_post_init(self, log__: tp.Any) -> None: super().model_post_init(log__) if self.event_types == "Event": warnings.warn( f"{self.name}: event_types has not been set, are you sure you want to apply this extractor to all events?" ) def _get_field_values( self, obj: pd.DataFrame | tp.Sequence[Event] | tp.Sequence[Segment], ) -> set[tp.Any]: events = self._event_types_helper.extract(obj) if not events: raise ValueError(f"No events found for {self.name}") return set(e._get_field_or_extra(self.event_field) for e in events)
[docs] def prepare( self, obj: pd.DataFrame | tp.Sequence[Event] | tp.Sequence[Segment] ) -> None: values = self._get_field_values(obj) types = [type(v) for v in values] if len(set(types)) > 1: raise ValueError( f"Field {self.event_field} has different types in events for {self.__class__.__name__}: {types}" ) # Enforce int or float types int_types = (int, np.int32, np.int64) float_types = (float, np.float32, np.float64) if types[0] not in (int_types + float_types): msg = f"Field {self.event_field} has type {types[0]}, which is not supported." raise ValueError(msg) self._dtype = torch.float32 if types[0] in float_types else torch.long
[docs] def get_static(self, event: etypes.Event) -> torch.Tensor: if self._dtype is None: msg = f"{self.name}: Must call extractor.prepare(events) before using the extractor." raise ValueError(msg) return torch.tensor( [event._get_field_or_extra(self.event_field)], dtype=self._dtype, )
[docs] class LabelEncoder(EventField): """Encode a given field from an event, e.g. to be used as a label. Parameters ---------- event_types : str or tuple of str Type of event(s) to apply this extractor to. event_field : str Field to encode from the event. allow_missing : bool If True, allow missing events without raising errors. treat_missing_as_separate_class: bool If True, treat missing events as a separate class with index -1, or one-hot vector with last index set to 1. This is only relevant if allow_missing is True. Note: If using LabelEncoder for a multilabel classification task, set this to False for missing labels to be represented by a vector of all zeros. return_one_hot : bool If True, return one-hot representation of the index. Otherwise, return an int in [0, n_unique_values - 1] (or the corresponding values provided in ``predefined_mapping``, and ``-1`` for missing events if ``treat_missing_as_separate_class=True``). predefined_mapping : dict, optional If provided, use this mapping from label to index instead of computing it from data. Values must be >= 0. If ``return_one_hot=True``, these indices MUST be contiguous and start from 0. """ treat_missing_as_separate_class: bool = False return_one_hot: bool = False predefined_mapping: dict[str, int] | None = None _label_to_ind: dict[str, int] = {} _n_classes: int = 0 def model_post_init(self, log__): super().model_post_init(log__) if ( self.allow_missing and not self.return_one_hot and not self.treat_missing_as_separate_class ): msg = ( "Missing events will be encoded using the default all-zero value " "(for example, 0 or a zero vector/tensor), which may be indistinguishable " "from a valid class if that class is also mapped to zeros. " "Set treat_missing_as_separate_class=True to avoid this." ) logger.warning(msg) if self.predefined_mapping is not None: indices = set(self.predefined_mapping.values()) if min(indices) < 0: msg = f"predefined_mapping values must be >= 0, got {min(indices)}." raise ValueError(msg) expected_indices = set(range(min(indices), min(indices) + len(indices))) if indices != expected_indices: msg = f"Label indices are not contiguous. Expected indices: {expected_indices}, " msg += f"but got: {indices}. This may cause issues with one-hot encoding " msg += "or class-based operations." logger.warning(msg) if ( self.treat_missing_as_separate_class and "__missing__" in self.predefined_mapping ): msg = "Key '__missing__' is reserved when treat_missing_as_separate_class is True." raise ValueError(msg)
[docs] def prepare( self, obj: pd.DataFrame | tp.Sequence[Event] | tp.Sequence[Segment] ) -> None: labels = self._get_field_values(obj) # Build mapping on first call; subsequent calls are no-ops if labels are known mapping = self._label_to_ind or self.predefined_mapping if mapping: unknown = labels - mapping.keys() if unknown: msg = ( f"Labels {unknown} are not in the existing mapping " f"{sorted(mapping)}. If prepare() is called per-subject " "after an initial full-data prepare(), this means new " "labels appeared. Use predefined_mapping to fix the " "mapping upfront." ) raise ValueError(msg) if self._label_to_ind: return self._label_to_ind = dict(self.predefined_mapping) # type: ignore[arg-type] else: if len(labels) < 2: logger.warning( f"LabelEncoder has only found one label: {labels}. " "This was probably not intended." ) self._label_to_ind = {label: i for i, label in enumerate(sorted(labels))} self._n_classes = len(set(self._label_to_ind.values())) if self.treat_missing_as_separate_class: self._label_to_ind["__missing__"] = -1 # Register in dict for transparency self._n_classes += 1 if self.return_one_hot: self._missing_default = torch.nn.functional.one_hot( torch.tensor([self._n_classes - 1], dtype=torch.long), num_classes=self._n_classes, )[0] else: self._missing_default = torch.tensor([-1], dtype=torch.long) events = self._event_types_helper.extract(obj) self( events[0], events[0].start, duration=0.001, trigger=events[0], )
[docs] def get_static(self, event: etypes.Event) -> torch.Tensor: if not self._label_to_ind: msg = f"{self.name}: Must call extractor.prepare(events) before using the extractor." raise ValueError(msg) inds = [self._label_to_ind[event._get_field_or_extra(self.event_field)]] label = torch.tensor(inds, dtype=torch.long) if self.return_one_hot: # Remove the batch dim label = torch.nn.functional.one_hot(label, num_classes=self._n_classes)[0] return label
[docs] class EventDetector(BaseExtractor): """Extracts time-aligned extractors from event annotations. Parameters ---------- event_types: str the event type to detect (e.g., "Keystroke"). frequency: float sampling frequency in Hz. mode: str mode of labeling ("dense", "start", "center", "duration"). allow_missing: bool if True, missing events are allowed without raising errors. """ event_types: str = "Event" frequency: float = 100.0 mode: tp.Literal["dense", "start", "duration", "center"] = "dense" allow_missing: bool = True # freeze aggregation as it is bypassed aggregation: tp.Literal["sum"] = "sum" def _get_timed_arrays(self, events, start: float, duration: float): freq = Frequency(self.frequency) n_samples = max(1, freq.to_ind(duration)) data = np.zeros((1, n_samples), dtype=np.float32) for event in events: idx_start = freq.to_ind(event.start - start) idx_end = freq.to_ind(event.start + event.duration - start) idx_center = freq.to_ind(event.start + event.duration / 2 - start) value = 1.0 if self.mode == "dense": data[0, max(0, idx_start) : min(n_samples, idx_end)] = value continue elif self.mode == "start": idx = idx_start elif self.mode == "center": idx = idx_center elif self.mode == "duration": idx = idx_center value = min(event.duration / duration, 1.0) else: raise NotImplementedError(f"Mode {self.mode} is not implemented") if 0 <= idx < n_samples: data[0, idx] = value yield ns.base.TimedArray( data=data, start=start, duration=duration, frequency=self.frequency, )