# Source code for neuralset.extractors.neuro

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import inspect
import logging
import typing as tp
from collections import defaultdict
from itertools import compress

import mne
import numpy as np
import pandas as pd
import pydantic
import sklearn.preprocessing
import torch
from exca import MapInfra
from exca.cachedict import DumpContext
from exca.helpers import DiscriminatedModel
from mne._fiff.pick import _VALID_CHANNEL_TYPES  # type: ignore
from tqdm import tqdm

import neuralset as ns
from neuralset import utils
from neuralset.base import TimedArray
from neuralset.events import etypes

from .base import BaseExtractor, BaseStatic

logger = logging.getLogger(__name__)
# Accepted input forms when preparing an extractor: a dataframe of events,
# a sequence of typed events, or a sequence of pre-cut segments.
DataframeOrEventsOrSegments = (
    pd.DataFrame | tp.Sequence[etypes.Event] | tp.Sequence[ns.segments.Segment]
)

# Vertex counts of the FreeSurfer fsaverage surface meshes
# (presumably per hemisphere — TODO confirm against usage).
FSAVERAGE_SIZES = {
    "fsaverage3": 642,
    "fsaverage4": 2562,
    "fsaverage5": 10242,
    "fsaverage6": 40962,
    "fsaverage7": 163842,
}


def _overlap(
    start1: float,
    duration1: float,
    start2: float,
    duration2: float,
) -> tuple[float, float]:
    """
    Computes the overlap times between two windows
    """
    starts = (start1, start2)
    stops = tuple(s + d for s, d in zip(starts, (duration1, duration2)))
    start = max(starts)
    stop = min(stops)
    return start, max(0, stop - start)


@DumpContext.register
class MneTimedArray(TimedArray):
    """TimedArray specialization for MNE recordings.

    Keeps the channel metadata needed to rebuild an ``mne.Info`` in a plain
    ``header`` dict, so cached arrays can be used without re-opening the
    original FIF file (replaces MneRawFif: avoids the 1-3s per-recording
    open overhead for seek/decompress).
    """

    # Always populated by from_native / __load_from_info__.
    header: dict[str, tp.Any]

    @classmethod
    def from_native(cls, raw: mne.io.Raw, start: float | None = None) -> "MneTimedArray":
        """Create an instance from an ``mne.io.Raw``, extracting channel header.

        Parameters
        ----------
        raw: mne.io.Raw
        start: optional float
            Timeline start. Defaults to ``raw.first_samp / sfreq``.
        """
        signal = raw.get_data().astype(np.float32)
        names = raw.ch_names
        # Names are stored comma-joined, so a comma inside a name would
        # corrupt the header round-trip.
        if any("," in name for name in names):
            raise ValueError("Channel names must not contain commas")
        locations = np.array([ch["loc"] for ch in raw.info["chs"]])
        # Comma-separated strings: measured ~4x faster to init than JSON lists.
        meta: dict[str, tp.Any] = {
            "ch_names": ",".join(names),
            "ch_types": ",".join(raw.get_channel_types()),
            "ch_locs": locations,
            "highpass": raw.info["highpass"],
            "lowpass": raw.info["lowpass"],
        }
        sfreq = raw.info["sfreq"]
        timeline_start = raw.first_samp / sfreq if start is None else start
        return cls(data=signal, frequency=sfreq, start=timeline_start, header=meta)

    @property
    def ch_names(self) -> list[str]:
        # Splitting the stored string is much cheaper than rebuilding an
        # mne.Info via to_info() (mne.create_info is ~3.3ms/call).
        return self.header["ch_names"].split(",")

    @property
    def ch_types(self) -> list[str]:
        return self.header["ch_types"].split(",")

    def to_info(self) -> mne.Info:
        """Rebuild an ``mne.Info`` from the stored header.

        Restores ch_names, ch_types, sfreq, channel locations and the
        highpass/lowpass bounds. Richer fields (bads, projs, filter
        history, …) are not preserved.
        """
        info = mne.create_info(
            self.ch_names, sfreq=float(self.frequency), ch_types=self.ch_types
        )
        stored_locs: np.ndarray = self.header["ch_locs"]
        for channel, loc in zip(info["chs"], stored_locs):
            channel["loc"][:] = loc
        with info._unlock():
            info["highpass"] = self.header["highpass"]
            info["lowpass"] = self.header["lowpass"]
        return info

    def to_native(self) -> mne.io.RawArray:
        """Rebuild a minimal ``mne.io.RawArray`` from the stored data and header."""
        n_first = int(round(self.start * float(self.frequency)))
        return mne.io.RawArray(
            self.data, self.to_info(), first_samp=n_first, verbose=False
        )


@DumpContext.register
class FmriTimedArray(TimedArray):
    """TimedArray specialization for fMRI, carrying spatial metadata for plotting.

    The ``header`` dict always holds ``space`` (output coordinate system,
    e.g. ``"fsaverage5"`` after surface projection, or the event's original
    space when unprojected) and ``preproc``. For unprojected volumetric data
    it additionally holds ``affine`` (4×4 ndarray), which enables
    ``to_native()``.
    """

    header: dict[str, tp.Any]  # narrow down typing (not optional)

    def to_native(self, time_index: int | None = None) -> tp.Any:
        """Rebuild a NIfTI image; only valid for unprojected volumetric data."""
        transform = self.header.get("affine")
        if transform is None:
            raise ValueError(
                f"No affine in header={self.header}"
                " — data was projected or is surface input"
            )
        import nibabel

        volume = np.asarray(self.data)
        if time_index is not None:
            volume = volume[..., time_index]
        return nibabel.Nifti1Image(volume, transform)


class MneRaw(BaseExtractor):
    """
    Feature extractor for raw MNE data files.

    This class handles loading, preprocessing, and caching of continuous MNE
    raw recordings. Extractor preparation and caching are executed with
    extractor.prepare()

    The steps of preprocessing, if specified, are ordered as follows:

    1. Channel selection
    2. Drop bad channels
    3. Bipolar referencing
    4. Notch filtering
    5. Band-pass filtering
    6. Hilbert transform
    7. Resampling
    8. Scaling
    9. Applying projectors
    10. Baseline correction (applied on segments)
    11. Clamp (applied on segments)

    Parameters
    ----------
    baseline : tuple of float, optional
        If provided as a tuple ``(start, end)``, defines the start and end
        times (in seconds) relative to the **beginning of the analysis
        window** — note this differs from MNE's convention, which defines
        baseline relative to the epoch onset. Used for baseline correction.
    picks : str or tuple of str
        Channels to pick from the raw data. Can be channel types (e.g.
        `'meg'`, `'eeg'`), channel names (e.g. `'MEG 0111'`), `"all"` for
        all channels, or `"data"` for data channels. Regular expressions are
        supported for selecting channel names automatically (e.g.
        `'MEG ...1'`).
    frequency : "native" or float, default="native"
        Target sampling frequency. If `"native"`, uses the frequency of the
        input recording.
    offset : float, default=0.0
        Time offset (in seconds) to apply to the event (typically for
        aligning with the response)
    apply_proj : bool, default=False
        Whether to apply projectors stored in the MNE raw object.
    filter : tuple of (float or None, float or None), optional
        Band-pass filter limits as ``(l_freq, h_freq)``. If None, no
        band-pass filtering is applied.
    apply_hilbert : bool, default=False
        If True, applies the Hilbert transform to extract the signal
        envelope.
    notch_filter : float or list of float, optional
        Frequencies (in Hz) to apply a notch filter at. For a single
        frequency (as float) or a list of frequencies (as list of float),
        all harmonics of specified frequencies up to 300 Hz will be
        filtered out.
    drop_bads : bool, default=False
        Whether to drop channels marked as bad in the MNE info structure.
    mne_cpus : int, default=-1
        Number of CPUs to use for multiprocessing in MNE operations.
    scaler : {"RobustScaler", "StandardScaler"}, optional
        Optional scaling strategy to normalize channel data using
        scikit-learn scalers.
    scale_factor : float, optional
        Optional multiplicative factor applied to the data after scaling,
        but before clamping. E.g, can be used to convert from V to mV or uV.
    clamp : float, optional
        Maximum absolute value for clamping the data after preprocessing.
    fill_non_finite : float or None, optional
        If a float, any non-finite values (NaN / +inf / -inf) found after
        preprocessing are replaced with this value and a warning is logged.
        If None (the default), no replacement is performed.
    bipolar_ref : tuple of (list of str, list of str), optional
        Explicit anode/cathode channel name lists for bipolar referencing
        via ``mne.set_bipolar_reference``. The first list contains anode
        names and the second list contains cathode names; they must have the
        same length. Applied after channel selection and dropping bad
        channels but before filtering. The original monopolar channels
        consumed by the pairs are removed and replaced with the new bipolar
        channels.
    channel_order : {"unique", "original"}
        If ``"original"``, assigns channel indices for each raw file
        (doesn't match channel names across files). Allows use of a subject
        layer of fixed dimension across subjects; prevents building a too
        large channel dimension when there are many subjects. Otherwise
        (``"unique"``, the default behavior) channels are numbered based on
        unique names across all subjects.
    allow_maxshield : bool, default=False
        If True, allow processing of Elekta/MEGIN MEG data recorded with
        Internal Active Shielding (MaxShield). Such recordings contain
        compensation signals that should normally be removed via SSS/tSSS
        (MaxFilter) before analysis. Does not affect the cache uid.
    """

    event_types: tp.Literal["Meg", "Eeg", "Emg", "Fnirs", "Ieeg"] = "Meg"
    frequency: tp.Literal["native"] | float = "native"
    offset: float = 0.0
    baseline: tuple[float, float] | None = None
    picks: str | tuple[str, ...] = pydantic.Field(("data",), min_length=1)
    apply_proj: bool = False
    filter: tuple[float | None, float | None] | None = None
    apply_hilbert: bool = False
    notch_filter: float | list[float] | None = None
    drop_bads: bool = False
    mne_cpus: int = -1
    infra: MapInfra = MapInfra(
        timeout_min=120,
        cpus_per_task=10,
        version="1",
    )
    scaler: None | tp.Literal["RobustScaler", "StandardScaler"] = None
    scale_factor: float | None = None
    clamp: float | None = None
    fill_non_finite: float | None = None
    bipolar_ref: tuple[list[str], list[str]] | None = None
    channel_order: tp.Literal["unique", "original"] = "unique"
    allow_maxshield: bool = False
    # channel name -> output row index mapping, filled by _update_channels
    _channels: dict[str, int] = {}

    @classmethod
    def _exclude_from_cls_uid(cls) -> list[str]:
        # These fields do not change the computed features, so they must not
        # invalidate the class uid.
        prev = super()._exclude_from_cls_uid()
        return prev + ["mne_cpus", "allow_maxshield"]

    def _exclude_from_cache_uid(self) -> list[str]:
        # Segment-level parameters (applied after the cached preprocessing)
        # do not affect the cached raw data.
        prev = super()._exclude_from_cache_uid()
        return prev + ["baseline", "offset", "scale_factor", "clamp"]

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        # check baseline
        if self.baseline is not None:
            issue = len(self.baseline) != 2
            issue |= not all(isinstance(b, float) for b in self.baseline)
            issue |= self.baseline[1] <= self.baseline[0]
            if issue:
                msg = f"baseline must be None or 2 floats, got {self.baseline}"
                raise ValueError(msg)
        if self.bipolar_ref is not None:
            anodes, cathodes = self.bipolar_ref
            if len(anodes) != len(cathodes):
                raise ValueError(
                    f"bipolar_ref anodes and cathodes must have equal length, "
                    f"got {len(anodes)} and {len(cathodes)}"
                )

    def prepare(self, obj: DataframeOrEventsOrSegments) -> None:
        """Specify how to load and preprocess the event. Can be overridden by user."""
        events: list[etypes.MneRaw]
        events = self._event_types_helper.extract(obj)  # type: ignore
        # Iterating _get_data triggers (and caches) the preprocessing, while
        # building the global channel-name -> index mapping.
        for ta in self._get_data(events):
            self._update_channels(ta.ch_names)
        if events:
            # Tiny extraction to exercise the full pipeline once.
            self(events[0], start=events[0].start, duration=0.001, trigger=events[0])

    @staticmethod
    def _pick_channels(
        raw: mne.io.Raw, picks_or_regexp: str | tuple[str, ...]
    ) -> mne.io.Raw:
        # A string that is not a known channel-type keyword is treated as a
        # regexp over channel names.
        if isinstance(picks_or_regexp, str) and picks_or_regexp not in [
            "all",
            "data",
            "meg",
            "ref_meg",
        ] + list(_VALID_CHANNEL_TYPES):
            sel = mne.pick_channels_regexp(raw.ch_names, picks_or_regexp)
            picks = [raw.ch_names[i] for i in sel]
        else:
            picks = picks_or_regexp  # type: ignore
        return raw.pick(picks, verbose=False)

    def _preprocess_raw(self, raw: mne.io.Raw, event: etypes.MneRaw) -> MneTimedArray:
        """Run the (cacheable) preprocessing chain on a loaded raw recording."""
        if raw.info.get("maxshield", False) and not self.allow_maxshield:
            raise ValueError(
                f"Data for {event!r} was recorded with Elekta MaxShield "
                "(Internal Active Shielding). These recordings contain "
                "compensation signals that corrupt brain data unless removed "
                "by SSS/tSSS (MaxFilter). Set allow_maxshield=True on the "
                "extractor to proceed anyway."
            )
        raw = self._pick_channels(raw, self.picks)
        if self.drop_bads:
            raw.load_data()
            raw = raw.drop_channels(raw.info["bads"])
        if self.bipolar_ref is not None:
            raw.load_data()
            anodes, cathodes = self.bipolar_ref
            raw = mne.set_bipolar_reference(raw, anodes, cathodes, verbose="WARNING")
        if self.notch_filter is not None:
            raw.load_data()
            raw = self._notch_filter(raw, self.notch_filter, self.mne_cpus)
        if self.filter is not None:
            raw.load_data()
            l_freq, h_freq = self.filter
            # Ignore lowpass filter if cutoff is higher than Nyquist frequency
            if h_freq is not None and h_freq >= raw.info["sfreq"] / 2:
                logger.warning(
                    "Lowpass filter cutoff frequency is higher than or equal to the Nyquist frequency. "
                    "Setting it to None."
                )
                h_freq = None
            raw.filter(l_freq, h_freq, n_jobs=self.mne_cpus, verbose=False)
        if self.apply_hilbert:
            raw.load_data()
            raw = raw.apply_hilbert(envelope=True)
        # Skip resampling when the target already matches the recording.
        if self.frequency not in ("native", event.frequency):
            raw.load_data()
            raw = raw.resample(
                float(self.frequency), n_jobs=self.mne_cpus, verbose=False
            )
        if self.scaler is not None:
            raw.load_data()
            # Scalers are fit per-channel, hence the double transpose.
            scaler = getattr(sklearn.preprocessing, self.scaler)()
            raw._data = scaler.fit_transform(raw._data.T).T
        if self.apply_proj:
            raw.apply_proj()
        if self.fill_non_finite is not None:
            raw.load_data()
            if not np.all(np.isfinite(raw._data)):
                n_nonfinite = np.count_nonzero(~np.isfinite(raw._data))
                val = self.fill_non_finite
                logger.warning(
                    "Non-finite values (NaN/inf) detected in %s data for "
                    "event %s (%d values). Replacing with %s before caching.",
                    self.event_types,
                    event,
                    n_nonfinite,
                    val,
                )
                raw._data = np.nan_to_num(raw._data, nan=val, posinf=val, neginf=val)
        return MneTimedArray.from_native(raw)

    @infra.apply(
        item_uid=lambda e: e._splittable_event_uid(),
        exclude_from_cache_uid="method:_exclude_from_cache_uid",
    )
    def _get_data(
        self, events: tp.Sequence[etypes.MneRaw]
    ) -> tp.Iterator[MneTimedArray]:
        # infra.apply caches each preprocessed recording keyed by event uid.
        for event in events:
            yield self._preprocess_raw(event.read(), event)

    def _get_timed_arrays(
        self, events: list[etypes.MneRaw], start: float, duration: float
    ) -> tp.Iterable[TimedArray]:
        for event in events:
            yield self._get_timed_array(event, start, duration)

    def _get_timed_array(
        self, event: etypes.MneRaw, start: float, duration: float
    ) -> TimedArray:
        """Cut one segment out of the cached recording, applying the
        segment-level steps (offset, scale_factor, baseline, clamp)."""
        start += self.offset
        # Extend window in case of disjoint baseline
        window_start, window_stop = start, start + duration
        if self.baseline is not None:
            if self.baseline[0] >= self.baseline[1]:
                msg = f"unexpected baseline:{self.baseline}"
                raise RuntimeError(msg)
            window_start = min(window_start, start + self.baseline[0])
            window_stop = max(window_stop, start + self.baseline[1])
        ta = next(self._get_data([event]))
        freq = ta.frequency
        ch_names = ta.ch_names
        # Relocate cached data to the event's timeline position and extract window
        ta = ta.with_start(event.start)
        tdata = ta.overlap(start=window_start, duration=window_stop - window_start)
        tdata.data = np.asarray(
            tdata.data
        )  # materialize ContiguousMemmap before arithmetic
        if self.scale_factor is not None:
            tdata.data = tdata.data * self.scale_factor
        # Apply baseline to the data
        if self.baseline is not None:
            baseline_duration = self.baseline[1] - self.baseline[0]
            base = tdata.overlap(start + self.baseline[0], baseline_duration).data
            if base.size:
                tdata.data = tdata.data - base.mean(1, keepdims=True)
        tdata = tdata.overlap(start=start, duration=duration)
        # initialize output
        channel_idx = self._get_channels(ch_names)
        timed_out = TimedArray(frequency=freq, start=start, duration=duration)
        out_shape = (max(self._channels.values()) + 1, timed_out.data.shape[-1])
        out = np.zeros(out_shape, dtype=np.float32)
        if tdata.start == start and tdata.duration == duration:
            timed_out = tdata  # bypass copy for efficiency
        else:
            timed_out += tdata
        if self.clamp is not None:
            timed_out.data = np.clip(
                timed_out.data, a_min=-self.clamp, a_max=self.clamp
            )
        # Scatter this recording's channels into the global channel layout.
        out[channel_idx, :] = timed_out.data
        timed_out.start -= self.offset
        timed_out.data = out
        return timed_out

    def _update_channels(self, ch_names: list[str]) -> None:
        """
        Update the indices assigned to channels based on unique indices
        across the dataset or based on original (per raw file) channel
        order.

        Example:: `self.channel_order == "original"`
            Stack based on original channel order only:
            subject1: [a, b, c]; subject2: [a, d, e]
            self._channels: {a: 0, b: 1, c: 2}
            self._channels: {a: 0, d: 1, e: 2}
            Allows use of a subject layer of fixed dimension across subjects.
            Prevents building a too large channel dimension when many
            subjects.

        Example:: `self.channel_order == "unique"` (default behavior)
            Unique channel stacking: we loop across all mne channels.
            If a channel is not known, we create a new dimension for it:
            dimension = len(self._channels)
            subject1: [a, b, c]; subject2: [a, d, e]
            self._channels: {a: 0, b: 1, c: 2}
            self._channels: {a: 0, b: 1, c: 2, d: 3, e: 4}
        """
        match self.channel_order:
            case "original":
                for i, ch in enumerate(ch_names):
                    self._channels[ch] = i
            case "unique":
                for ch in ch_names:
                    if ch not in self._channels:
                        self._channels[ch] = len(self._channels)

    def _get_channels(self, ch_names: list[str]) -> list[int]:
        # Lazily bootstrap the mapping when prepare() was not called.
        if not self._channels:
            self._update_channels(ch_names)
        try:
            channel_idx = [self._channels[ch] for ch in ch_names]
        except KeyError as e:
            msg = f"Channel {e} not found in the channel mapping, likely because "
            msg += "this dataset contains recordings with different sets of channel "
            msg += "names. Try calling self.prepare on the whole events dataframe."
            raise KeyError(msg) from e
        return channel_idx

    @staticmethod
    def _notch_filter(
        raw: mne.io.Raw, notch_filter: float | list[float], mne_cpus: int
    ) -> mne.io.Raw:
        """Notch-filter each requested frequency and its harmonics up to 300 Hz
        (bounded by the Nyquist frequency)."""
        notch_filter = (
            [notch_filter] if isinstance(notch_filter, float) else notch_filter
        )
        notch_freqs: list[float] = []
        for freq in notch_filter:
            notch_freqs.extend(
                np.arange(freq, min(raw.info["sfreq"] / 2, 301), freq).tolist()  # type: ignore
            )
        if len(notch_freqs) == 0:
            logger.info("Not applying notch filter as no valid frequencies were found.")
        else:
            logger.info("Applying notch filter with notch_freqs=%s", sorted(notch_freqs))
            raw = raw.notch_filter(
                notch_freqs, phase="zero", filter_length="auto", n_jobs=mne_cpus
            )
        return raw
class MegExtractor(MneRaw):
    """
    MEG feature extractor.

    Parameters
    ----------
    picks: default = ("meg",)
        pick "meg" channels by default.
    """

    event_types: tp.Literal["Meg"] = "Meg"
    picks: str | tuple[str, ...] = pydantic.Field(("meg",), min_length=1)
class EegExtractor(MneRaw):
    """
    EEG feature extractor.

    Parameters
    ----------
    picks: default = ("eeg",)
        pick "eeg" channels by default.
    """

    event_types: tp.Literal["Eeg"] = "Eeg"
    # NOTE(review): unlike MegExtractor, a bare str is not accepted here
    # (annotation narrows the parent's str | tuple) — confirm intentional.
    picks: tuple[str, ...] = pydantic.Field(("eeg",), min_length=1)
class EmgExtractor(MneRaw):
    """
    EMG feature extractor.

    Parameters
    ----------
    picks: default = ("emg",)
        pick "emg" channels by default.
    """

    event_types: tp.Literal["Emg"] = "Emg"
    picks: tuple[str, ...] = pydantic.Field(("emg",), min_length=1)
class IeegExtractor(MneRaw):
    """
    Intracranial EEG feature extractor.

    Parameters
    ----------
    picks: default = ("seeg", "ecog", )
        pick "seeg" and "ecog" channels by default.
    reference: "bipolar" or None, default=None
        If "bipolar", applies a bipolar reference to the data, i.e., uses
        neighboring electrode as reference. Uses mne.set_bipolar_reference
        under the hood. [ieeg1]_

    Notes
    -----
    Bipolar reference currently can only be applied to sEEG. It expects that
    the channels in raw.ch_names are ordered by probe, and with ascending
    order for each probe, and the names consist of the probe name followed
    by the position on the probe, e.g.:
    ['OF1', 'OF2', 'OF3', ..., 'OF12', 'OF13', 'OF14',
    'H1', 'H2', 'H3', ..., 'H13', 'H14', 'H15', ...]
    WATCH-OUT: this will take the closest electrode on the probe, meaning
    that if the neighboring electrode is missing for some reason (e.g.:
    rejected before applying the reference) then the next electrode will be
    used for referencing.

    References
    ----------
    .. [ieeg1] https://mne.tools/stable/generated/mne.set_bipolar_reference.html
    """

    event_types: tp.Literal["Ieeg"] = "Ieeg"
    picks: tuple[str, ...] = pydantic.Field(
        (
            "seeg",
            "ecog",
        ),
        min_length=1,
    )
    reference: tp.Literal["bipolar"] | None = None

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        if self.reference == "bipolar" and self.picks != ("seeg",):
            raise ValueError(
                f"Bipolar reference can only be applied on seeg signals, got picks {self.picks} instead"
            )
        if self.reference == "bipolar" and self.bipolar_ref is not None:
            raise ValueError(
                "Cannot use both reference='bipolar' (auto-derived from "
                "neighboring electrodes) and bipolar_ref (explicit "
                "anode/cathode lists) at the same time."
            )

    def _preprocess_raw(self, raw: mne.io.Raw, event: etypes.MneRaw) -> MneTimedArray:
        # Pick and drop bads before the automatic bipolar referencing, then
        # delegate the remaining preprocessing chain to the parent class.
        raw = self._pick_channels(raw, self.picks)
        if self.drop_bads:
            raw.load_data()
            raw = raw.drop_channels(raw.info["bads"])
        if self.reference == "bipolar":
            raw.load_data()
            raw = self._apply_bipolar_ref(raw)
        return super()._preprocess_raw(raw, event)

    def _apply_bipolar_ref(self, raw: mne.io.Raw) -> mne.io.Raw:
        """
        Apply bipolar reference for EEG, i.e., uses neighboring electrode as
        reference.

        Parameters
        ----------
        raw : mne.io.Raw
            Raw instance that will be referenced.

        Returns
        -------
        raw : mne.io.Raw
            Referenced Raw object
        """
        logger.info("Applying bipolar reference")
        logger.warning(
            "Assumes raw.ch_names are ordered by probe, with ascending order for each probe, and the names consists of the probe name followed by the probe position."
        )
        logger.warning(
            "WATCH-OUT: Taking the closest electrode on the probe... If the neighboring electrode is missing (e.g., rejected before applying the reference) then the next electrode will be used as the reference."
        )
        reference_ch = list(raw.ch_names)
        # Pair each channel with its immediate successor in the list.
        anodes = reference_ch[0:-1]
        cathodes = reference_ch[1:]
        to_del = []
        # allow constructions that are on the same probe
        # (e.g., HA 5-HA 6 or HA 5-HA 7 if contact HA 6 is turned off)
        # don't allow constructions across probes
        # (e.g., HA 5 with FA 6)
        for i, (a, c) in enumerate(zip(anodes, cathodes)):
            # Pairs whose non-digit characters differ belong to different probes.
            if [j for j in a if not j.isdigit()] != [j for j in c if not j.isdigit()]:
                to_del.append(i)
        # Delete in reverse so earlier indices stay valid.
        for idx in to_del[::-1]:
            del anodes[idx]
            del cathodes[idx]
        bipol = mne.set_bipolar_reference(raw, anodes, cathodes, verbose="WARNING")
        return bipol
class SpikesExtractor(BaseExtractor):
    """Feature extractor for spike data stored in HDF5/NWB files.

    Reads spike times from HDF5 files and creates a dense binned array of
    shape (n_units, n_time_bins) at the specified frequency.

    The preprocessing steps, if specified, are ordered as follows:

    1. Spike binning at target frequency
    2. Scaling
    3. Baseline correction (applied on segments)
    4. Clamp (applied on segments)

    Parameters
    ----------
    frequency : "native" or float, default="native"
        Target sampling frequency for spike binning. If ``"native"``, uses
        the frequency declared in the Spikes event.
    offset : float, default=0.0
        Time offset (in seconds) applied to the segment window.
    baseline : tuple of float, optional
        If provided as ``(start, end)``, defines the baseline correction
        window in seconds relative to the segment start.
    scaler : {"RobustScaler", "StandardScaler"}, optional
        Scaling strategy to normalize channel data using scikit-learn
        scalers.
    scale_factor : float, optional
        Multiplicative factor applied to the data after scaling but before
        clamping.
    clamp : float, optional
        Maximum absolute value for clamping after preprocessing.
    channel_order : {"unique", "original"}, default="unique"
        ``"unique"``: channels are numbered based on unique names across all
        recordings. ``"original"``: channel indices follow per-recording
        order, enabling a fixed-size channel dimension across subjects.
    """

    event_types: tp.Literal["Spikes"] = "Spikes"
    frequency: tp.Literal["native"] | float = "native"
    offset: float = 0.0
    baseline: tuple[float, float] | None = None
    scaler: None | tp.Literal["RobustScaler", "StandardScaler"] = None
    scale_factor: float | None = None
    clamp: float | None = None
    channel_order: tp.Literal["unique", "original"] = "unique"
    requirements: tp.ClassVar[tp.Any] = ("h5py",)
    infra: MapInfra = MapInfra(
        timeout_min=120,
        cpus_per_task=10,
        version="1",
    )
    # unit name -> output row index mapping, filled by _update_channels
    _channels: dict[str, int] = {}

    def _exclude_from_cache_uid(self) -> list[str]:
        # Segment-level parameters do not affect the cached binned data.
        prev = super()._exclude_from_cache_uid()
        return prev + ["baseline", "offset", "scale_factor", "clamp"]

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        if self.baseline is not None:
            issue = len(self.baseline) != 2
            issue |= not all(isinstance(b, float) for b in self.baseline)
            issue |= self.baseline[1] <= self.baseline[0]
            if issue:
                msg = f"baseline must be None or 2 floats, got {self.baseline}"
                raise ValueError(msg)

    def prepare(self, obj: DataframeOrEventsOrSegments) -> None:
        # Trigger (and cache) binning for all events while building the
        # global unit-name -> index mapping.
        events: list[etypes.Spikes]
        events = self._event_types_helper.extract(obj)  # type: ignore
        for ta in self._get_data(events):
            assert ta.header is not None
            self._update_channels(ta.header["ch_names"].split(","))
        if events:
            self(events[0], start=events[0].start, duration=0.001, trigger=events[0])

    @staticmethod
    def _bin_spikes(nwb_file: tp.Any, sfreq: float) -> tuple[np.ndarray, list[str]]:
        """Bin spike times from an open HDF5/NWB file into a dense array.

        Parameters
        ----------
        nwb_file : h5py.File
            Open HDF5 file following the NWB spikes convention
            (``units/id``, ``units/spike_times``, ``units/spike_times_index``).
        sfreq : float
            Sampling frequency for binning (Hz).

        Returns
        -------
        data : np.ndarray of shape (n_units, n_bins)
            Dense spike-count array.
        ch_names : list of str
            One name per unit.
        """
        units_id = nwb_file["units/id"][:]
        spikes = nwb_file["units/spike_times"][:]
        spike_times_index = nwb_file["units/spike_times_index"][:]
        max_time = float(np.max(spikes))
        n_bins = int(np.ceil(max_time * sfreq))
        data = np.zeros((len(units_id), n_bins), dtype=np.float32)
        for row_idx in range(len(units_id)):
            # spike_times_index holds the (exclusive) end offset of each
            # unit's slice in the flat spike_times array (NWB ragged layout).
            start_idx = 0 if row_idx == 0 else int(spike_times_index[row_idx - 1])
            stop_idx = int(spike_times_index[row_idx])
            unit_spikes = spikes[start_idx:stop_idx]
            bin_indices = np.floor(unit_spikes * sfreq).astype(int)
            bin_indices = np.clip(bin_indices, 0, n_bins - 1)
            # np.add.at accumulates correctly for repeated bin indices
            # (several spikes falling in the same bin).
            np.add.at(data[row_idx, :], bin_indices, 1)
        if (
            "units/electrodes" in nwb_file
            and "general/extracellular_ephys/electrodes" in nwb_file
        ):
            # Include the electrode id in the unit name when available.
            unit_electrodes = nwb_file["units/electrodes"][:]
            electrode_ids = nwb_file["general/extracellular_ephys/electrodes/id"][:]
            ch_names = [
                f"unit_{uid}_{electrode_ids[unit_electrodes[i]]}"
                for i, uid in enumerate(units_id)
            ]
        else:
            ch_names = [f"unit_{uid}" for uid in units_id]
        return data, [str(n) for n in ch_names]

    def _preprocess_spikes(self, event: etypes.Spikes) -> TimedArray:
        """Read, bin and (optionally) scale one Spikes event into a TimedArray."""
        import h5py  # type: ignore[import-untyped]

        sfreq = (
            float(event.frequency)
            if self.frequency == "native"
            else float(self.frequency)
        )
        nwb_file = event.read()
        try:
            data, ch_names = self._bin_spikes(nwb_file, sfreq)
        finally:
            # event.read() may return an open h5py handle; always close it.
            if isinstance(nwb_file, h5py.File):
                nwb_file.close()
        if self.scaler is not None:
            # Scalers are fit per-unit, hence the double transpose.
            scaler_cls = getattr(sklearn.preprocessing, self.scaler)()
            data = scaler_cls.fit_transform(data.T).T
        header: dict[str, tp.Any] = {"ch_names": ",".join(ch_names)}
        return TimedArray(data=data, frequency=sfreq, start=event.start, header=header)

    @infra.apply(
        item_uid=lambda e: str(e.study_relative_path()),
        exclude_from_cache_uid="method:_exclude_from_cache_uid",
    )
    def _get_data(self, events: tp.Sequence[etypes.Spikes]) -> tp.Iterator[TimedArray]:
        # infra.apply caches each binned recording keyed by its file path.
        for event in events:
            yield self._preprocess_spikes(event)

    def _get_timed_arrays(
        self, events: list[etypes.Spikes], start: float, duration: float
    ) -> tp.Iterable[TimedArray]:
        for event in events:
            yield self._get_timed_array(event, start, duration)

    def _get_timed_array(
        self, event: etypes.Spikes, start: float, duration: float
    ) -> TimedArray:
        """Cut one segment out of the cached binned data, applying the
        segment-level steps (offset, scale_factor, baseline, clamp)."""
        start += self.offset
        # Extend the extraction window when the baseline lies outside it.
        window_start, window_stop = start, start + duration
        if self.baseline is not None:
            if self.baseline[0] >= self.baseline[1]:
                raise RuntimeError(f"unexpected baseline:{self.baseline}")
            window_start = min(window_start, start + self.baseline[0])
            window_stop = max(window_stop, start + self.baseline[1])
        ta = next(self._get_data([event]))
        freq = ta.frequency
        assert ta.header is not None
        ch_names = ta.header["ch_names"].split(",")
        # Relocate cached data to the event's timeline and cut the window.
        ta = ta.with_start(event.start)
        tdata = ta.overlap(start=window_start, duration=window_stop - window_start)
        tdata.data = np.asarray(tdata.data)
        if self.scale_factor is not None:
            tdata.data = tdata.data * self.scale_factor
        if self.baseline is not None:
            baseline_duration = self.baseline[1] - self.baseline[0]
            base = tdata.overlap(start + self.baseline[0], baseline_duration).data
            if base.size:
                tdata.data = tdata.data - base.mean(1, keepdims=True)
        tdata = tdata.overlap(start=start, duration=duration)
        channel_idx = self._get_channels(ch_names)
        timed_out = TimedArray(frequency=freq, start=start, duration=duration)
        out_shape = (max(self._channels.values()) + 1, timed_out.data.shape[-1])
        out = np.zeros(out_shape, dtype=np.float32)
        if tdata.start == start and tdata.duration == duration:
            timed_out = tdata  # bypass copy for efficiency
        else:
            timed_out += tdata
        if self.clamp is not None:
            timed_out.data = np.clip(
                timed_out.data, a_min=-self.clamp, a_max=self.clamp
            )
        # Scatter this recording's units into the global channel layout.
        out[channel_idx, :] = timed_out.data
        timed_out.start -= self.offset
        timed_out.data = out
        return timed_out

    def _update_channels(self, ch_names: list[str]) -> None:
        # Same indexing policy as MneRaw._update_channels (see its docstring).
        match self.channel_order:
            case "original":
                for i, ch in enumerate(ch_names):
                    self._channels[ch] = i
            case "unique":
                for ch in ch_names:
                    if ch not in self._channels:
                        self._channels[ch] = len(self._channels)

    def _get_channels(self, ch_names: list[str]) -> list[int]:
        # Lazily bootstrap the mapping when prepare() was not called.
        if not self._channels:
            self._update_channels(ch_names)
        try:
            channel_idx = [self._channels[ch] for ch in ch_names]
        except KeyError as e:
            msg = f"Channel {e} not found in the channel mapping, likely because "
            msg += "this dataset contains recordings with different sets of channel "
            msg += "names. Try calling self.prepare on the whole events dataframe."
            raise KeyError(msg) from e
        return channel_idx
class FnirsExtractor(MneRaw):
    """
    Functional Near-Infrared Spectroscopy (fNIRS) feature extractor.

    Provides preprocessing and handling of fNIRS data using functions from
    mne.preprocessing.nirs, including channel selection, optical density conversion,
    haemodynamic response computation, and signal enhancement.

    The order of preprocessing, if specified, is as follows:

    1. Channel selection
    2. Drop bad channels
    3. Exclude channels based on source-detector distance
    4. Convert raw intensity data to optical density
    5. Exclude channels based on scalp coupling index
    6. Apply Temporal Derivative Distribution Repair (TDDR) to reduce motion artifacts
    7. Compute haemodynamic responses from optical density
    8. Band-pass filtering
    9. Enhance negatively correlated signals
    10. Resampling
    11. Scaling

    All steps executed with extractor.prepare()

    Parameters
    -----
    picks: default=("fnirs",)
        Channels to select for analysis, picks "fnirs" channels by default.
    distance_threshold: float | None, default=None
        Minimum source-detector distance for channel selection. Channels with
        distances below this threshold will be excluded.
    compute_optical_density: bool, default=False
        Whether to convert raw intensity data to optical density.
    scalp_coupling_index_threshold: float | None, default=None
        Minimum scalp coupling index for channel inclusion.
        Requires `compute_optical_density=True`.
    apply_tddr: bool, default=False
        Whether to apply Temporal Derivative Distribution Repair to reduce
        motion artifacts.
    compute_heamo_response: bool, default=False
        Whether to compute haemodynamic responses from optical density.
        Requires `compute_optical_density=True`.
    partial_pathlength_factor: float, default=0.1
        Partial pathlength factor used in the Beer-Lambert law for haemodynamic
        response calculation.
    enhance_negative_correlation: bool, default=False
        Whether to enhance negatively correlated signals.
        Requires `compute_heamo_response=True`.

    Notes
    -----
    - Filtering, resampling, scaling are applied if `filter` or `frequency`
      attributes are set (see parent class).
    """

    event_types: tp.Literal["Fnirs"] = "Fnirs"
    picks: tuple[str, ...] = pydantic.Field(("fnirs",), min_length=1)
    distance_threshold: float | None = None
    compute_optical_density: bool = False
    scalp_coupling_index_threshold: float | None = None
    apply_tddr: bool = False  # Apply temporal derivative distribution repair
    compute_heamo_response: bool = False
    partial_pathlength_factor: float = 0.1
    enhance_negative_correlation: bool = False
    requirements: tp.ClassVar[tp.Any] = ("mne-nirs",)

    def model_post_init(self, log__: tp.Any) -> None:
        """Validate that the requested preprocessing steps are mutually consistent.

        Raises ValueError if a step is enabled without the earlier step it
        depends on (optical density -> SCI thresholding / haemodynamic
        response -> negative-correlation enhancement).
        """
        super().model_post_init(log__)
        # Ensure preprocessing steps are consistent with one another
        if self.compute_heamo_response and not self.compute_optical_density:
            msg = "Computing haemodynamic response requires computing optical density first."
            raise ValueError(msg)
        if self.scalp_coupling_index_threshold is not None:
            if not self.compute_optical_density:
                raise ValueError(
                    "Thresholding with the SCI requires computing optical density first."
                )
        if self.enhance_negative_correlation and not self.compute_heamo_response:
            msg = "Applying negative correlation enhancement requires haemodynamic responses."
            raise ValueError(msg)

    def _preprocess_raw(self, raw: mne.io.Raw, event: etypes.MneRaw) -> MneTimedArray:
        """Run the fNIRS preprocessing pipeline on a Raw object.

        Steps run in the order documented on the class; each step is applied
        only when its corresponding field is set.
        NOTE(review): `drop_bads`, `filter`, `frequency`, `scaler` and
        `mne_cpus` are presumably inherited from the MneRaw parent class (not
        visible here) — confirm against the parent definition.
        """
        raw = self._pick_channels(raw, self.picks)
        if self.drop_bads:
            raw.load_data()
            raw = raw.drop_channels(raw.info["bads"])
        if self.distance_threshold is not None:
            # Exclude channels whose source-detector distance is too small
            dists = mne.preprocessing.nirs.source_detector_distances(raw.info)
            if np.isnan(dists).any():
                msg = "Some or all distances are nan, please fix montage information."
                raise ValueError(msg)
            picks = compress(raw.ch_names, dists > self.distance_threshold)
            raw = raw.pick(list(picks))
        if self.compute_optical_density:
            raw = mne.preprocessing.nirs.optical_density(raw)
        if self.scalp_coupling_index_threshold is not None:
            # SCI is computed on optical density data (validated in model_post_init)
            sci = mne.preprocessing.nirs.scalp_coupling_index(raw)
            picks = compress(raw.ch_names, sci > self.scalp_coupling_index_threshold)
            raw = raw.pick(list(picks))
        if self.apply_tddr:
            raw = mne.preprocessing.nirs.temporal_derivative_distribution_repair(raw)
        if self.compute_heamo_response:
            raw = mne.preprocessing.nirs.beer_lambert_law(
                raw, ppf=self.partial_pathlength_factor
            )
        if self.filter is not None:
            raw.load_data()
            raw.filter(
                self.filter[0], self.filter[1], n_jobs=self.mne_cpus, verbose=False
            )
        if self.enhance_negative_correlation:
            # mne-nirs is an optional dependency (see class `requirements`)
            import mne_nirs

            raw = mne_nirs.signal_enhancement.enhance_negative_correlation(raw)
        if self.frequency not in ("native", event.frequency):
            raw.load_data()
            raw = raw.resample(float(self.frequency), n_jobs=self.mne_cpus, verbose=False)
        if self.scaler is not None:
            raw.load_data()
            # Scalers operate with time as the first dimension, hence the transposes
            scaler = getattr(sklearn.preprocessing, self.scaler)()
            raw._data = scaler.fit_transform(raw._data.T).T
        return MneTimedArray.from_native(raw)
# ---------------------------------------------------------------------------
# FmriExtractor sub-configs
# ---------------------------------------------------------------------------


class FmriCleaner(pydantic.BaseModel):
    """Configuration for fMRI signal cleaning.

    Parameters
    ----------
    standardize : str | bool
        Standardization method forwarded to ``nilearn.signal.clean``.
    detrend : bool
        Whether to remove linear trends.
    low_pass : float | None
        Low-pass filter cutoff in Hz.
    high_pass : float | None
        High-pass filter cutoff in Hz.
    filter: tp.Literal["butterworth", "cosine"] | None
        Filter to use: "butterworth" or "cosine".
    ensure_finite: bool
        Whether to set nans to 0.
    """

    model_config = pydantic.ConfigDict(extra="forbid")
    standardize: tp.Literal["zscore_sample", "zscore", "psc"] | bool = "zscore_sample"
    detrend: bool = True
    high_pass: float | None = None
    low_pass: float | None = None
    filter: tp.Literal["butterworth", "cosine"] | None = None
    ensure_finite: bool = True

    def clean(self, data: np.ndarray, t_r: float) -> np.ndarray:
        """Clean ``data`` (features x time) with ``nilearn.signal.clean``.

        No-op when no cleaning option is active.
        """
        if (
            self.detrend
            or self.standardize
            or (self.high_pass is not None)
            or (self.low_pass is not None)
        ):
            import nilearn.signal

            data = data.T  # set time as first dim
            shape = data.shape
            # nilearn expects 2-D (time, features): flatten any extra dims
            data = nilearn.signal.clean(
                data.reshape(shape[0], -1),
                t_r=t_r,
                standardize=self.standardize,
                high_pass=self.high_pass,
                low_pass=self.low_pass,
                filter=self.filter,
                detrend=self.detrend,
                ensure_finite=self.ensure_finite,
            )
            data = data.reshape(shape).T
        return data


class BaseFmriProjector(DiscriminatedModel, discriminator_key="name"):
    """Base class for spatial projection sub-configs.

    Uses ``name`` as the pydantic discriminator key, following the
    ``NamedModel`` convention used throughout the codebase.
    Subclasses must implement :meth:`apply`.
    """

    def apply(self, rec: tp.Any, **kwargs: tp.Any) -> np.ndarray:
        """Project a recording to the target space.

        Parameters
        ----------
        rec
            NIfTI-like recording (4-D volumetric or 2-D surface).
        **kwargs
            Subclass-specific options (e.g. ``standardize`` for atlas maskers).

        Returns
        -------
        np.ndarray
            Projected data with shape ``(n_features, time)``.
        """
        raise NotImplementedError


class SurfaceProjector(BaseFmriProjector):
    """Project data to an fsaverage surface mesh.

    For volumetric data, this uses ``nilearn.surface.vol_to_surf`` to project the
    data to the surface. For surface data, this simply downsamples the data to
    the target mesh resolution.

    Fields beyond ``mesh`` mirror the keyword arguments of
    ``nilearn.surface.vol_to_surf`` and are forwarded to it.

    Examples
    --------
    >>> SurfaceProjector(mesh="fsaverage5")
    >>> SurfaceProjector(mesh="fsaverage6", radius=5.0, interpolation="nearest")
    """

    mesh: str
    radius: float = 3.0
    interpolation: tp.Literal["linear", "nearest"] = "linear"
    kind: tp.Literal["auto", "line", "ball"] = "auto"
    n_samples: int | None = None
    mask_img: tp.Any | None = None
    depth: list[float] | None = None
    # Lazily-fetched fsaverage mesh, cached across calls
    _mesh: tp.Any | None = pydantic.PrivateAttr(default=None)

    def model_post_init(self, __context: tp.Any) -> None:
        super().model_post_init(__context)
        # mesh must be one of the fsaverage resolutions known to FSAVERAGE_SIZES
        if self.mesh not in FSAVERAGE_SIZES:
            raise ValueError(f"mesh must be an fsaverage mesh (got {self.mesh!r})")

    def get_mesh(self) -> tp.Any:
        """Fetch (once) and return the nilearn fsaverage mesh bunch."""
        if self._mesh is None:
            from nilearn import datasets

            fsaverage = datasets.fetch_surf_fsaverage(self.mesh)
            self._mesh = fsaverage
        return self._mesh

    def apply(self, rec: tp.Any) -> np.ndarray:
        """Project a 4-D volume (vol_to_surf) or downsample a 2-D surface image.

        Returns an array of shape (n_vertices_left + n_vertices_right, time).
        """
        if len(rec.shape) == 4:
            # 4-D volume data → use nilearn.surface.vol_to_surf
            mesh = self.get_mesh()
            from nilearn.surface import vol_to_surf

            hemis = [
                vol_to_surf(
                    rec,
                    surf_mesh=mesh[f"pial_{hemi}"],
                    inner_mesh=mesh[f"white_{hemi}"],
                    radius=self.radius,
                    interpolation=self.interpolation,
                    kind=self.kind,
                    n_samples=self.n_samples,
                    mask_img=self.mask_img,
                    depth=self.depth,
                )
                for hemi in ("left", "right")
            ]
            # left hemisphere vertices first, then right
            return np.vstack(hemis)
        elif len(rec.shape) == 2:
            # 2-D surface data → downsample to target mesh resolution
            # assumes the input stacks both hemispheres (left then right)
            n_vertices = rec.shape[0] // 2
            if n_vertices not in list(FSAVERAGE_SIZES.values()) or rec.shape[0] % 2:
                msg = f"The detected number of vertices ({rec.shape[0]}) is not in {list(FSAVERAGE_SIZES.values())}"
                raise ValueError(msg)
            n_vertices_resampled = FSAVERAGE_SIZES.get(self.mesh)
            data = rec.get_fdata()
            if n_vertices < n_vertices_resampled:
                raise NotImplementedError(
                    f"Cannot upsample from {n_vertices} vertices to {n_vertices_resampled} vertices"
                )
            if n_vertices > n_vertices_resampled:
                # fsaverage meshes are nested: the first k vertices of a higher
                # resolution mesh coincide with the lower resolution mesh
                left = data[:n_vertices_resampled, :]
                right = data[n_vertices : n_vertices + n_vertices_resampled, :]
                data = np.concatenate([left, right], axis=0)
            return data
        else:
            raise ValueError(
                f"Unexpected shape {rec.shape} (should have 2 or 4 dimensions)"
            )


class MaskProjector(BaseFmriProjector):
    """Apply a mask to volumetric images.

    This converts the volumetric data of shape [x, y, z, time] to
    [n_voxels, time].

    Parameters
    ----------
    mask: str
        Mask to load from nilearn.datasets.
    resolution: int
        Resolution of the mask in millimeters. Must be 1 or 2.
    """

    mask: tp.Literal["mni152_brain", "mni152_gm", "mni152_wm", "subcortical"] = (
        "mni152_gm"
    )
    resolution: tp.Literal[1, 2] = 2
    # Lazily-built mask image, cached across calls
    _mask: tp.Any | None = pydantic.PrivateAttr(default=None)

    def get_mask(self) -> tp.Any:
        """Build (once) and return the mask NIfTI image."""
        import nibabel as nib
        from nilearn import datasets

        if self._mask is None:
            if self.mask == "subcortical":
                # Harvard-Oxford subcortical atlas: zero out non-subcortical
                # structures (cortex, white matter, brain stem, background)
                atlas = datasets.fetch_atlas_harvard_oxford(
                    f"sub-maxprob-thr50-{self.resolution}mm"
                )
                excluded = ["Cortex", "White", "Stem", "Background"]
                selected_indices = [
                    i
                    for i, label in enumerate(atlas.labels)
                    if any([exc.lower() in label.lower() for exc in excluded])
                ]
                mask_data = atlas.maps.get_fdata()
                mask_data[np.isin(mask_data, selected_indices)] = 0
                mask = nib.Nifti1Image(mask_data, atlas.maps.affine, atlas.maps.header)
            else:
                mask_func = getattr(datasets, f"load_{self.mask}_mask")
                mask = mask_func(resolution=self.resolution)
            self._mask = mask
        return self._mask

    def apply(self, rec: tp.Any, **kwargs: tp.Any) -> np.ndarray:
        """Resample ``rec`` onto the mask grid and keep in-mask voxels only."""
        if len(rec.shape) != 4:
            raise ValueError(f"Mask projection requires 4D data, got {rec.shape}")
        from nilearn.image import resample_to_img

        mask = self.get_mask()
        rec = resample_to_img(rec, mask, copy_header=True)
        return rec.get_fdata()[mask.get_fdata() > 0, :]


class AtlasProjector(BaseFmriProjector):
    """Project to atlas parcels via a nilearn masker.

    This converts volumetric data of shape [x, y, z, time] to
    [n_parcels, time], where each parcel is a weighted sum of voxels from the
    volumetric data.

    Parameters
    ----------
    atlas : str
        Atlas name; must match a ``nilearn.datasets.fetch_atlas_{atlas}``
        function.
    atlas_kwargs : dict | None
        Keyword arguments forwarded to the nilearn fetch function. Validated
        against the function signature at init time.

    Examples
    --------
    >>> AtlasProjector(atlas="schaefer_2018", atlas_kwargs={"n_rois": 400})
    >>> AtlasProjector(atlas="difumo", atlas_kwargs={"dimension": 64})
    """

    atlas: str
    atlas_kwargs: dict[str, tp.Any] | None = None
    # Lazily-built nilearn masker, cached across calls
    _masker: tp.Any = pydantic.PrivateAttr(default=None)

    def model_post_init(self, __context: tp.Any) -> None:
        super().model_post_init(__context)
        import nilearn.datasets as nds

        # Fail fast at config time if the atlas or its kwargs are invalid
        func_name = f"fetch_atlas_{self.atlas}"
        if not hasattr(nds, func_name):
            raise ValueError(f"Atlas {self.atlas!r} not found in nilearn.datasets")
        if self.atlas_kwargs is not None:
            func = getattr(nds, func_name)
            params = inspect.signature(func).parameters
            for k in self.atlas_kwargs:
                if k not in params:
                    raise ValueError(f"Invalid {func_name} argument: {k!r}")

    def get_atlas(self) -> tp.Any:
        """Fetch and return the nilearn atlas bunch."""
        from nilearn import datasets

        atlas_func = getattr(datasets, f"fetch_atlas_{self.atlas}")
        atlas = atlas_func(**(self.atlas_kwargs or {}))
        return atlas

    def _get_masker(
        self,
    ) -> tp.Any:
        """Build (and cache) the nilearn atlas masker."""
        if self._masker is None:
            from nilearn.maskers import NiftiLabelsMasker, NiftiMapsMasker

            maps = self.get_atlas().maps
            if isinstance(maps, str):
                import nibabel as nib

                maps = nib.load(maps)
            if maps.get_fdata().ndim == 3:
                # deterministic atlas
                masker = NiftiLabelsMasker(
                    labels_img=maps,
                )
            else:
                # probabilistic atlas
                masker = NiftiMapsMasker(maps_img=maps)
            self._masker = masker
        return self._masker

    def apply(self, rec: tp.Any) -> np.ndarray:
        """Project ``rec`` onto the atlas parcels, returning (n_parcels, time)."""
        if len(rec.shape) != 4:
            raise ValueError(f"Atlas projection requires 4D data, got {rec.shape}")
        masker = self._get_masker()
        # masker yields (time, n_parcels): transpose to (n_parcels, time)
        return masker.fit_transform(rec).T


class FoscoProjector(BaseFmriProjector):
    """Subset HCP grayordinates to FOSCO visual ROIs.

    Applies a spatial mask using ``hcp_utils`` to keep only vertices belonging
    to the FOSCO ROI set (41 visual areas from the HCP MMP1.0 parcellation).
    Expects 2-D input of shape ``(nodes, time)``.

    Examples
    --------
    >>> FoscoProjector()

    Config usage::

        "neuro": {
            "name": "FmriExtractor",
            "projection": {"name": "FoscoProjector"},
            "from_space": "atlas_msmall",
        }
    """

    # The 41 FOSCO visual area names (HCP MMP1.0 labels without hemisphere prefix)
    ROIS: tp.ClassVar[list[str]] = [
        "V1", "MST", "V6", "V2", "V3", "V4", "MT", "V8", "V3A", "RSC",
        "POS2", "V7", "IPS1", "FFC", "V3B", "LO1", "LO2", "PIT", "PCV",
        "STV", "7m", "POS1", "23d", "v23ab", "d23ab", "31pv", "LIPv",
        "VIP", "MIP", "PH", "TPOJ1", "TPOJ2", "TPOJ3", "IP2", "IP1",
        "IP0", "VMV1", "VMV3", "LO3", "VMV2", "VVC",
    ]
    # Lazily-built boolean grayordinate mask, cached across calls
    _mask: np.ndarray | None = pydantic.PrivateAttr(default=None)

    def _get_mask(self) -> np.ndarray:
        """Build (once) the boolean mask selecting FOSCO ROI vertices."""
        if self._mask is None:
            import hcp_utils as hcp  # type: ignore[import-not-found]

            # Keep labels from either hemisphere ("L_"/"R_" prefix) whose base
            # name is in the FOSCO ROI list
            self._mask = np.isin(
                hcp.mmp.map_all,
                [
                    k
                    for k, v in hcp.mmp.labels.items()
                    if (v.startswith("R_") or v.startswith("L_"))
                    and v[2:] in self.ROIS
                ],
            )
        return self._mask

    def apply(self, rec: tp.Any, **kwargs: tp.Any) -> np.ndarray:
        """Keep only FOSCO ROI rows of a 2-D (nodes, time) recording."""
        data = rec.get_fdata()
        if data.ndim != 2:
            raise ValueError(
                f"FoscoProjector expects 2D (nodes, time) data, got {data.ndim}D"
            )
        return data[self._get_mask(), :]


# ---------------------------------------------------------------------------
# FmriExtractor
# ---------------------------------------------------------------------------
class FmriExtractor(BaseExtractor):
    """fMRI feature extraction with optional projection, signal cleaning, and
    caching to a NumPy memmap.

    Input: a volumetric image of shape [x, y, z, time] or a surface image of
    shape [n_vertices, time].

    Preprocessing pipeline (each step is optional):

    1. **Spatial smoothing** (``fwhm``): If active, smooths the image with an
       isotropic Gaussian kernel.
    2. **Spatial projection** (``projection``): If active, projects the data to
       an array [n_features, time]. There are three projection types:
       SurfaceProjector (-> [n_vertices, time]), MaskProjector
       (-> [n_voxels, time]), and AtlasProjector (-> [n_parcels, time]).
    3. **Signal cleaning** (``cleaning``): If active, cleans the data using
       ``nilearn.signal.clean``.
    4. **Temporal resampling** (``frequency``): If active, resamples the data
       to the target frequency, using np.interp.

    Parameters
    ----------
    offset : float
        Seconds to shift TRs forward to align delayed BOLD response.
    projection : BaseFmriProjector | None
        Spatial projection config — one of:

        * ``SurfaceProjector(mesh="fsaverage5")``
        * ``MaskProjector(mask="mni152_gm", resolution=2)``
        * ``AtlasProjector(atlas="schaefer_2018", atlas_kwargs={"n_rois": 400})``
        * ``None`` — keep raw volumetric / surface data
    cleaning : FmriCleaner | None
        Signal cleaning config, passed to ``nilearn.signal.clean``. ``None``
        skips all cleaning.
    frequency : ``"native"`` | float
        Target sampling frequency.
    padding : int | ``"auto"`` | None
        Pad 1-D+T data to a uniform voxel count across subjects.
    from_space : str | ``"auto"`` | None
        Input space to load. ``None`` (default) passes through when only one
        space is present, raises when multiple are found. ``"auto"`` uses a
        projection-aware heuristic. An explicit string selects that space.
    from_preproc : str | tuple[str, ...] | None
        Filter events by preprocessing pipeline.
    fwhm : float | None
        Full width at half maximum (in mm) for isotropic spatial smoothing via
        ``nilearn.image.smooth_img``. Applied after masking and before
        projection. ``None`` skips smoothing.
    """

    requirements: tp.ClassVar[tp.Any] = ("nilearn",)
    offset: float = 0.0
    event_types: tp.Literal["Fmri"] = "Fmri"
    projection: BaseFmriProjector | None = None
    cleaning: FmriCleaner | None = FmriCleaner()
    frequency: tp.Literal["native"] | float = "native"
    padding: int | tp.Literal["auto"] | None = None
    from_space: str | None = None
    from_preproc: str | tuple[str, ...] | None = None
    fwhm: float | None = None
    infra: MapInfra = MapInfra(
        timeout_min=120,
        cpus_per_task=10,
        version="2",
    )
    # Resolved padding length (set from `padding`, or computed in prepare()
    # when padding="auto")
    _padding: int | None = None

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        if isinstance(self.padding, int):
            self._padding = self.padding
        if not self.projection and self.infra.folder is not None:
            # Raw volumetric data is huge; caching it defeats the purpose
            logger.warning(
                f"{self.name}: volumetric fMRI (with projection=None) should not be cached: preprocessing should be fast, and caching will use too much disk space. Please set infra.folder=None"
            )

    def _exclude_from_cache_uid(self) -> list[str]:
        # offset and padding are applied after caching, so they must not
        # invalidate the cache
        return super()._exclude_from_cache_uid() + ["offset", "padding"]

    def _auto_filter_fmri_events(
        self, fmri_events: list[etypes.Fmri]
    ) -> list[etypes.Fmri]:
        """Filter fMRI events by ``from_preproc``, ``from_space``, and ``projection``.

        1. **Preproc filter** -- drop events not matching ``self.from_preproc``
           (skipped when ``None``).
        2. **Space selection**:

           - Explicit string (e.g. ``"MNI152NLin2009cAsym"``) -- keep only that
             space; raises if no events match.
           - ``"auto"`` -- projection-aware heuristic: ``SurfaceProjector``
             prefers fsaverage meshes, other projectors prefer MNI spaces.
             Requires ``self.projection`` to be set.
           - ``None`` (default) -- pass through if one space, raise if multiple.
        """
        from_preproc = self.from_preproc
        from_space = self.from_space
        if isinstance(from_preproc, str):
            from_preproc = (from_preproc,)
        if from_preproc is not None:
            available_preprocs = {e.preproc for e in fmri_events}
            missing = set(from_preproc) - available_preprocs
            if missing:
                raise ValueError(
                    f"from_preproc={missing} not found. Available: {available_preprocs}"
                )
            fmri_events = [e for e in fmri_events if e.preproc in from_preproc]
        preproc_ctx = (
            f" (after from_preproc={self.from_preproc!r} filter)"
            if from_preproc is not None
            else ""
        )
        spaces = {e.space for e in fmri_events}
        if len(spaces) > 1 and from_space is None:
            logger.warning(
                "Multiple fMRI spaces found: %s. If you trigger list_segments "
                "on Fmri events, each space produces a separate segment. "
                "Filter with from_space or a QueryEvents transform "
                "before segmentation to avoid duplicates.",
                spaces,
            )
        if isinstance(from_space, str) and from_space != "auto":
            # Explicit space name: keep only matching events
            filtered = [e for e in fmri_events if e.space == from_space]
            if not filtered:
                raise ValueError(
                    f"from_space={from_space!r} matched no events. "
                    f"Available spaces: {spaces}{preproc_ctx}"
                )
            return filtered
        if len(spaces) <= 1:
            return fmri_events
        if from_space is None:
            raise ValueError(
                f"Multiple spaces found ({spaces}), set from_space explicitly "
                "or use from_space='auto'"
            )
        # from_space == "auto": projection-aware heuristic
        if self.projection is None:
            raise ValueError(
                f"from_space='auto' requires a projection to choose among "
                f"{spaces}. Set from_space to an explicit space name, or "
                "configure a projection."
            )
        fs = "fsaverage"
        if isinstance(
            self.projection, SurfaceProjector
        ) and self.projection.mesh.startswith(fs):
            # Surface projection: prefer fsaverage spaces (lowest resolution first)
            candidates = [f"{fs}{n}" for n in range(3, 8)] + [fs]
        else:
            # Volumetric projection: prefer MNI spaces
            candidates = ["MNI152NLin2009cAsym"] + [s for s in spaces if "MNI" in s]
        best = next((c for c in candidates if c in spaces), None)
        if best is None:
            raise ValueError(
                f"No matching space for projection={self.projection!r} among "
                f"{spaces}. Set from_space to an explicit space name."
            )
        return [e for e in fmri_events if e.space == best]

    def prepare(self, obj: DataframeOrEventsOrSegments) -> None:
        """Precompute (and cache) the preprocessed data for all events in ``obj``."""
        all_events: list[etypes.Fmri] = self._event_types_helper.extract(obj)  # type: ignore[assignment]
        events = self._auto_filter_fmri_events(all_events)
        if self.padding == "auto":
            # for padding, we first need everything to be preprocessed
            # but we need on an object without padding since missing_default filling
            # will apply the extractor on 1 event, and we'll need the padding length for that
            self.infra.clone_obj(padding=None).prepare(events)
            self._padding = max(ta.data.shape[0] for ta in self._get_data(events))  # type: ignore
        # (recompute prepare to just fill the missing default value)
        super().prepare(events)

    @staticmethod
    def _resample_trs(
        data: np.ndarray, orig_tr: float, new_tr: float
    ) -> tuple[np.ndarray, np.ndarray]:
        """Resample TRs (i.e. in the time dimension) using np.interp.

        Since np.interp does not extrapolate, we cannot resample beyond the
        original time range. This means we likely lose one TR (or a few) at the
        very end of the timeline. These "invalid" TRs are filled with NaNs to
        keep the output data shape consistent with the original duration.
        """
        voxel_dims, n_trs = data.shape[:-1], data.shape[-1]
        trs = np.arange(n_trs) * orig_tr
        new_trs = np.arange(0, trs[-1] + orig_tr - new_tr, step=new_tr)
        flat_data = data.reshape(-1, n_trs)
        # NOTE(review): the docstring and the warning below claim trailing TRs
        # are filled with NaNs, but right=voxel[-1] extends the last sample
        # instead (which is also np.interp's default) — confirm which behavior
        # is intended.
        resamp_data = np.stack(
            [np.interp(new_trs, trs, voxel, right=voxel[-1]) for voxel in flat_data],
            axis=0,
        )
        resamp_data = resamp_data.reshape(*voxel_dims, -1)
        # Evaluate how many TRs were dropped at the end because of no extrapolation
        n_new_trs_without_extrapolation = int(trs[-1] / new_tr) + 1
        dropped_trs = len(new_trs) - n_new_trs_without_extrapolation
        if dropped_trs > 0:
            msg = (
                f"{dropped_trs} invalid TR(s) at the end of the timeline due to resampling without"
                f" interpolation from T={orig_tr:0.2f}s to T={new_tr:0.2f}s. These invalid TRs "
                "have been filled with NaNs."
            )
            logger.warning(msg)
        return resamp_data, new_trs

    def _get_relevant_events(
        self, events: tp.Any, trigger: etypes.Event | None
    ) -> list[etypes.Event]:
        """Filter Fmri events by preproc/space, then delegate aggregation to base."""
        fmri_events = self._auto_filter_fmri_events(
            self._event_types_helper.extract(events),  # type: ignore[arg-type]
        )
        return super()._get_relevant_events(fmri_events, trigger)

    def _preprocess_event(self, event: etypes.Fmri) -> FmriTimedArray:
        """Run the full preprocessing pipeline (mask, smooth, project, clean,
        resample) on a single event and wrap the result in an FmriTimedArray."""
        rec = event.read()
        # 4-D image => volumetric (3 spatial dims); otherwise surface (1 dim)
        space_dims = 3 if len(rec.shape) == 4 else 1
        if space_dims == 3 and event.mask_filepath is not None:
            rec = utils.get_masked_bold_image(rec, event.read_mask())
        # --- spatial smoothing ---
        if self.fwhm is not None:
            if space_dims != 3:
                raise ValueError(
                    f"Spatial smoothing (fwhm={self.fwhm}) requires volumetric "
                    f"(4-D) data, but got {len(rec.shape)}-D input."
                )
            from nilearn.image import smooth_img

            rec = smooth_img(rec, self.fwhm)
        # --- spatial projection ---
        header: dict[str, tp.Any] = {"preproc": event.preproc, "space": event.space}
        if self.projection is not None:
            data = self.projection.apply(rec)
            if isinstance(self.projection, SurfaceProjector):
                # surface projection changes the output space to the target mesh
                header["space"] = self.projection.mesh
        else:
            data = rec.get_fdata()
            if space_dims == 3:
                # keep the affine so volumetric data can be located in space
                header["affine"] = np.array(rec.affine, dtype=np.float64)
        # --- signal cleaning ---
        if self.cleaning is not None:
            data = self.cleaning.clean(data, t_r=1 / event.frequency)
        # --- temporal resampling ---
        if self.frequency not in ("native", event.frequency):
            orig_tr = 1 / event.frequency
            new_tr = 1 / float(self.frequency)
            data, _ = self._resample_trs(data, orig_tr, new_tr)
        freq = event.frequency if self.frequency == "native" else self.frequency
        return FmriTimedArray(
            data=data.astype(np.float32),
            frequency=freq,
            start=event.start,
            duration=event.duration,
            header=header,
        )

    @infra.apply(
        item_uid=lambda e: e._splittable_event_uid(),
        exclude_from_cache_uid=_exclude_from_cache_uid,
    )
    def _get_data(self, events: list[etypes.Fmri]) -> tp.Iterable[TimedArray]:
        # cached per-event via exca's MapInfra
        for event in tqdm(events, disable=len(events) < 2, desc="Computing fmri data"):
            yield self._preprocess_event(event)

    def _get_timed_arrays(
        self, events: list[etypes.Fmri], start: float, duration: float
    ) -> tp.Iterable[TimedArray]:
        """Yield cached arrays, applying padding and the BOLD offset shift."""
        if self.padding == "auto" and self._padding is None:
            raise RuntimeError("Fmri.prepare needs to be called to compute auto padding")
        for _event, ta in zip(events, self._get_data(events)):
            data = ta.data
            if self._padding is not None:
                if data.ndim != 2:
                    raise ValueError(f"Only 1D+T FMRI can be padded, got {data.shape=}")
                padding = self._padding - data.shape[0]
                if padding < 0:
                    raise ValueError(
                        f"Padding to length {self._padding} but got {data.shape=}"
                    )
                # pad the feature dimension up to the uniform length
                data = np.pad(data, [(0, self._padding - data.shape[0]), (0, 0)])
            yield FmriTimedArray(
                data=data,
                frequency=ta.frequency,
                start=ta.start - self.offset,
                duration=ta.duration,
                header=ta.header,
            )
class ChannelPositions(BaseStatic):
    """Channel positions in 2D or 3D, extracted from a Raw object's mne.Info.

    3D positions (``n_spatial_dims=3``) are always returned in MNE's **head
    coordinate frame**, which is defined by the LPA, RPA, and nasion fiducial
    landmarks (origin at the midpoint of LPA–RPA, x-axis toward RPA, y-axis
    toward nasion, z-axis upward). This holds regardless of whether positions
    come from a named standard montage or from the raw data's channel
    locations. See
    https://mne.tools/stable/documentation/implementation.html#coordinate-systems
    for details.

    Parameters
    ----------
    neuro:
        Extractor that defines the preprocessing steps applied to the Raw
        objects. This can either be specified in the config, or built with the
        `build` method.
    n_spatial_dims: int
        Number of spatial dimensions (i.e. coordinates) to extract for each
        channel. For `n_spatial_dims=2`, the 2D projection of the channel
        positions as obtained through `mne.Layout` will be used. For
        `n_spatial_dims=3`, the 3D positions are extracted from `mne.Montage`
        in head coordinate frame.
    layout_or_montage_name:
        Name of the Layout or Montage to use. See `mne.channels.read_layout()`
        for a list of valid layouts and `mne.channels.get_builtin_montages()`
        for standard montages. If not provided, the function will look for a
        layout in the `Raw.info` object or for a montage in the `Raw` object.
        **Note**: MNE's standard montages are only for EEG systems; MEG
        montages must be loaded from the raw data.
    include_ref_eeg:
        If True, additionally try to extract the position of the anode of
        bipolar EEG channel (e.g. for the channel name "P3-Cz", return position
        of both "P3" and "Cz"), yielding and output of shape
        (n_channels, n_spatial_dims * 2). If True, `event_types` must be one of
        Eeg or Ieeg.
    normalize:
        If True, min-max normalize channel positions between 0 and 1 across
        each dimension. If False, 2D positions are in arbitrary units given by
        the mne.Layout projection, while 3D positions will be in head
        coordinate frame (approximately in the range [-0.1, 0.1] meters).
    factor:
        Factor to scale the channel positions by. E.g. set it to 10.0 to get 3D
        coordinates in decimeters, which yields values approximately in the
        range [-1, 1].
    """

    event_types: tp.Literal["MneRaw", "Meg", "Eeg", "Ieeg"] = "MneRaw"
    neuro: MneRaw | None = None
    n_spatial_dims: tp.Literal[2, 3] = 2
    layout_or_montage_name: str | None = None
    include_ref_eeg: bool = False
    normalize: bool = True
    factor: float = 1.0
    _neuro: MneRaw = pydantic.PrivateAttr()
    # Value to use for channels that are not found in the layout
    INVALID_VALUE: tp.ClassVar[float] = -0.1
    infra: MapInfra = MapInfra()

    def model_post_init(self, log__: tp.Any) -> None:
        """Validate config consistency and bind the private ``_neuro`` extractor."""
        super().model_post_init(log__)
        if self.neuro is not None:
            if self.event_types not in {"MneRaw", self.neuro.event_types}:
                msg = f"event_types={self.event_types} must match "
                msg += f"neuro.event_types={self.neuro.event_types}."
                raise ValueError(msg)
            eegs = {"Eeg", "Ieeg"}
            if self.include_ref_eeg and self.neuro.event_types not in eegs:
                msg = "include_ref_eeg=True is only supported for events_types "
                msg += f"Eeg and Ieeg, got {self.event_types}."
                raise ValueError(msg)
            if self.n_spatial_dims == 2 and isinstance(self.neuro, MegExtractor):
                raise NotImplementedError(
                    "n_spatial_dims=2 is not supported for MEG data. "
                )
            self._neuro = self.neuro

    def build(self, neuro: MneRaw) -> "ChannelPositions":
        """Return a copy of this config bound to the given ``neuro`` extractor.

        Silently disables include_ref_eeg when the extractor's event type does
        not support it.
        """
        config = self.model_dump()
        config["neuro"] = neuro
        if config.get("include_ref_eeg") and neuro.event_types not in {"Eeg", "Ieeg"}:
            logger.warning(
                "include_ref_eeg=True is not supported for event_types=%s; "
                "overriding to False.",
                neuro.event_types,
            )
            config["include_ref_eeg"] = False
        return self.__class__(**config)

    def prepare(self, obj: DataframeOrEventsOrSegments) -> None:
        events = self._event_types_helper.extract(obj)
        if not hasattr(self, "_neuro"):
            raise ValueError(
                "The neuro extractor is not set. Either set it in the config or call build."
            )
        self._neuro.prepare(events)  # Ensure the Raw objects have been precomputed
        super().prepare(events)

    def _get_layout_positions(self, ta: MneTimedArray) -> dict[str, list[float]]:
        """Map channel name -> 2D layout position ([x, y])."""
        if self.layout_or_montage_name is not None:
            layout = mne.channels.read_layout(self.layout_or_montage_name)
        else:
            try:
                layout = mne.find_layout(ta.to_info())
            except RuntimeError as err:
                msg = "No valid layout found. Please specify a layout to load with argument "
                msg += "`layout_name` or explicitly set a montage in the study class (e.g. with "
                msg += "`raw.set_montage()`)."
                raise ValueError(msg) from err
        mapping = {name: pos[:2].tolist() for name, pos in zip(layout.names, layout.pos)}
        return mapping

    def _get_positions_from_ch_locs(self, ta: MneTimedArray) -> dict[str, list[float]]:
        """Extract 3D channel positions from cached ch_locs, filtering out zero positions."""
        ch_locs: np.ndarray = ta.header["ch_locs"]
        mapping: dict[str, list[float]] = {}
        for i, name in enumerate(ta.ch_names):
            loc = np.asarray(ch_locs[i, :3])
            # [0, 0, 0] means the position is unknown. This shouldn't happen for
            # MEG sensors or when ch_locs are properly set, but we filter
            # defensively for robustness.
            if not np.all(loc == 0):
                mapping[name] = loc.tolist()
        if not mapping:
            raise RuntimeError(
                "No valid channel positions found in the cached TimedArray header. "
                "Please specify a layout_or_montage_name."
            )
        return mapping

    def _get_montage_positions(self, ta: MneTimedArray) -> dict[str, list[float]]:
        """Map channel name -> 3D position in head frame, from a standard montage
        if `layout_or_montage_name` is set, otherwise from cached ch_locs."""
        if self.layout_or_montage_name is not None:
            montage = mne.channels.make_standard_montage(self.layout_or_montage_name)
            # Standard montages are in "native" frame: transform to head frame
            native_head_t = mne.channels.compute_native_head_t(montage)
            ch_pos = montage.get_positions()["ch_pos"]
            # Standard montages shouldn't contain [0, 0, 0] positions, but we
            # filter defensively for consistency with _get_positions_from_ch_locs.
            mapping = {
                name: mne.transforms.apply_trans(native_head_t["trans"], pos).tolist()
                for name, pos in ch_pos.items()
                if not np.all(pos == 0)
            }
            if not mapping:
                raise RuntimeError(
                    "No valid montage positions found (all positions were zero). "
                    "Please check the montage name."
                )
            return mapping
        return self._get_positions_from_ch_locs(ta)

    def _compute_positions(self, ta: MneTimedArray) -> torch.Tensor:
        """Get scaled channel positions from a cached MneTimedArray.

        Returns
        -------
        torch.Tensor :
            Positions for each channel, of shape (n_channels, n_spatial_dims).
            When including reference channel (self.include_ref_eeg is True),
            output shape is (n_channels, n_spatial_dims * 2) where each row
            contains the coordinates of the cathode channel followed by the
            coordinates of the anode.
        """
        pos_mapping: dict[str, list[float]] = {}
        is_meg = isinstance(self._neuro, MegExtractor)
        if self.n_spatial_dims == 2:
            pos_mapping = self._get_layout_positions(ta)
        elif self.n_spatial_dims == 3:
            if is_meg:
                pos_mapping = self._get_positions_from_ch_locs(ta)
            else:
                pos_mapping = self._get_montage_positions(ta)
        ta_ch_names = ta.ch_names
        ch_names: list[str] = []
        for ch_name in ta_ch_names:
            if self.include_ref_eeg:
                # bipolar names like "P3-Cz": keep both cathode and anode;
                # "" placeholder when no anode is present
                parts = ch_name.split("-", 1)
                ch_names.append(parts[0])
                ch_names.append(parts[1] if len(parts) == 2 else "")
            else:
                ch_names.append(ch_name.split("-")[0])
        valid_inds = [i for i, n in enumerate(ch_names) if n in pos_mapping]
        invalid_names = [n for n in ch_names if n and n not in pos_mapping]
        if not valid_inds:
            raise ValueError(f"No channel has valid positions: {ta_ch_names}.")
        if len(valid_inds) < 0.1 * len(ch_names):
            unique_invalid_names = set(invalid_names)
            msg = f"Fewer than 10% of the channels have valid positions: {unique_invalid_names}."
            logger.warning(msg)
        # Unknown channels get NaN here; replaced with INVALID_VALUE below
        positions = np.array(
            [
                (
                    pos_mapping[name]
                    if name in pos_mapping
                    else [np.nan] * self.n_spatial_dims
                )
                for name in ch_names
            ]
        )
        if self.normalize:
            # per-dimension min-max normalization, ignoring NaNs
            ptp = np.nanmax(positions, axis=0, keepdims=True) - np.nanmin(
                positions, axis=0, keepdims=True
            )
            if (ptp == 0.0).any():
                ptp[ptp == 0.0] = 1.0  # avoid division by zero on flat dimensions
            positions = (positions - np.nanmin(positions, axis=0, keepdims=True)) / ptp
        positions *= self.factor
        positions = np.nan_to_num(positions, nan=self.INVALID_VALUE)
        n_spatial_dims = self.n_spatial_dims
        if self.include_ref_eeg:
            n_spatial_dims *= 2  # type: ignore
        # include_ref_eeg interleaved cathode/anode rows fold back into columns
        positions = positions.reshape(len(ta_ch_names), n_spatial_dims)
        # Scatter into the global channel table; missing channels keep INVALID_VALUE
        channel_idx = self._neuro._get_channels(ta_ch_names)
        out = torch.full((len(self._neuro._channels), n_spatial_dims), self.INVALID_VALUE)
        out[channel_idx, :] = torch.from_numpy(positions).float()
        return out

    def _exclude_from_cache_uid(self) -> list[str]:
        # forward the nested extractor's cache-uid exclusions under "neuro."
        ex = super()._exclude_from_cache_uid()
        if not hasattr(self, "_neuro"):
            raise RuntimeError("Should not happen")
        neuro_ex = self._neuro._exclude_from_cache_uid()
        return ex + [f"neuro.{n}" for n in neuro_ex]

    @infra.apply(
        item_uid=lambda e: str(e.study_relative_path()),
        exclude_from_cache_uid="method:_exclude_from_cache_uid",
    )
    def _get_data(self, events: list[etypes.MneRaw]) -> tp.Iterator[torch.Tensor]:
        if not hasattr(self, "_neuro"):
            raise ValueError(
                "The neuro extractor is not set. Either set it in the config or call build."
            )
        for ta in self._neuro._get_data(events):
            yield self._compute_positions(ta)

    def get_static(self, event: etypes.MneRaw) -> torch.Tensor:
        """Return the position tensor for a single recording event."""
        return next(self._get_data([event]))
class HrfConvolve(BaseExtractor):
    """
    Convolve the output of an extractor by the Hemodynamic Response Function.

    Note that this extractor does not support timelines with events.start < 0.
    Note that this is stored (cached) by timeline. If the events of a timeline
    change, e.g. by using a different transform while keeping a similar
    timeline id, then there will be a silent bug.

    Parameters
    ----------
    extractor: BaseExtractor
        the extractor used for feature extraction
    frequency: pydantic.PositiveFloat
        frequency of the TimedArray yielded by this extractor
        (NOTE(review): a previous docstring described this as Literal["native"]
        and "never used", which contradicts the field type and its use below)
    """

    event_types: str | tuple[str, ...] = "Event"
    extractor: BaseExtractor
    offset: float = 0
    duration: pydantic.PositiveFloat | None = None
    frequency: pydantic.PositiveFloat
    # results are cached per timeline: a folder or keep_in_ram is mandatory
    infra: MapInfra = MapInfra(keep_in_ram=True)
    aggregation: tp.Literal["mean"] = "mean"
    _device: tp.Literal["auto", "cuda", "cpu"] = "auto"
    requirements: tp.ClassVar[tuple[str, ...]] = ("nilearn",)

    def model_post_init(self, log__: tp.Any) -> None:
        """Validate the configuration and inherit event types from the
        wrapped extractor."""
        if self.infra.keep_in_ram is False and self.infra.folder is None:
            msg = "HrfConvolve requires a cache (folder=my/cache or keep_in_ram=True)."
            raise ValueError(msg)
        if isinstance(self.extractor, BaseStatic) and self.extractor.frequency == 0.0:
            raise ValueError(
                "HrfConvolve cannot crop a static extractor as it is timeless."
            )
        # this extractor operates on the same events as the wrapped one
        self.event_types = self.extractor.event_types
        super().model_post_init(log__)

    def prepare(self, obj: tp.Any) -> None:
        """Prepare the wrapped extractor, then warm the per-timeline cache."""
        from neuralset.events.utils import extract_events

        # Prepare low-level extractors
        self.extractor.prepare(obj)
        # Group events by timeline
        events = extract_events(obj, types=self._event_types_helper)
        timelines_dict = defaultdict(list)
        for event in events:
            timelines_dict[event.timeline].append(event)
        timelines = list(timelines_dict.values())
        # Cache prepare each hrf-convolved timeline
        self._get_data(timelines)
        # Populate shape
        self(timelines[0], 0, duration=0.001, trigger=timelines[0][0])
        # TODO store somewhere a hash of the timeline events to avoid
        # silent bugs, e.g. on subset of the timeline events?

    @infra.apply(item_uid=lambda tl_events: tl_events[0].timeline)
    def _get_data(self, timelines: list[list[etypes.Event]]) -> tp.Iterator[TimedArray]:
        """Yield one HRF-convolved TimedArray per timeline.

        Cached by exca, keyed on the timeline id of the first event of each
        group. Raises ValueError on negative event starts, mixed timelines,
        or an unprepared wrapped extractor.
        """
        # Validate every group up front, before any expensive extraction
        for events in timelines:
            if any([ev.start < 0 for ev in events]):
                raise ValueError("HrfConvolve does not support events with start<0")
            if len(set([ev.timeline for ev in events])) > 1:
                raise ValueError("events should be grouped by timeline.")
            if self.extractor._effective_frequency is None:
                raise ValueError("hrf.extractor must be prepared.")
        for events in timelines:
            from nilearn.glm.first_level.hemodynamic_models import spm_hrf

            # Get timeline duration (latest event stop)
            duration = 0.0
            for event in events:
                event_stop = event.start + event.duration
                if event_stop > duration:
                    duration = event_stop
            # Extract features over the whole timeline, extended by the HRF
            # support so the tail of the response is captured.
            TIME_LENGTH = 32.0  # HRF kernel length in seconds
            feature_data = self.extractor(
                events, 0, duration + TIME_LENGTH, trigger=None
            ).float()
            # Flatten all batch dims; time stays last
            *batch_dims, n_times = feature_data.shape
            feature_flat = feature_data.reshape(-1, n_times)  # (N, n_times)
            # Sample the canonical SPM HRF at the extractor's frequency
            dt = 1.0 / (self.extractor._effective_frequency)
            hrf = spm_hrf(dt, oversampling=1, time_length=TIME_LENGTH)
            kernel = torch.from_numpy(hrf).float().view(1, 1, len(hrf))
            # conv1d computes cross-correlation; flip to get true convolution
            kernel = torch.flip(kernel, dims=[2])
            device = self._device
            if device == "auto":
                device = "cuda" if torch.cuda.is_available() else "cpu"
            # full convolution: pad on left by len(hrf) - 1
            conv = torch.nn.functional.conv1d(
                feature_flat.unsqueeze(1).to(device),  # (N, 1, n_times)
                kernel.to(device),
                padding=len(hrf) - 1,
            )  # (N, 1, n_times + len(hrf) - 1)
            data = conv.squeeze(1)  # (N, n_times + len(hrf) - 1)
            # NOTE(review): no explicit crop is applied here, so the output
            # keeps n_times + len(hrf) - 1 samples despite the symmetric
            # padding; presumably downstream cropping via TimedArray
            # start/frequency handles this — confirm this is intended.
            # Restore original batch dims and yield as a timeline-anchored array
            yield TimedArray(
                frequency=self.frequency,
                start=0,
                data=data.reshape(*batch_dims, -1).cpu().numpy(),
            )

    def _get_timed_arrays(
        self, events: list[etypes.Event], start: float, duration: float
    ) -> tp.Iterator[TimedArray]:
        """Yield the cached convolved array for this group of events.

        `start` and `duration` are unused here: cropping is handled by the
        caller; raises ValueError if the wrapped extractor is unprepared.
        """
        if self.extractor._effective_frequency is None:
            raise ValueError("You must first hrf.prepare(events).")
        yield from self._get_data([events])