Source code for neuralset.base

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import ast
import collections
import importlib.util
import logging
import pprint
import typing as tp
from pathlib import Path

import exca
import exca.steps
import numpy as np
import pydantic
import yaml
from exca.cachedict import DumpContext

PathLike = str | Path


# # # # # CONFIGURE LOGGER # # # # #
logger = logging.getLogger("neuralset")
_handler = logging.StreamHandler()
_formatter = logging.Formatter(
    "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(message)s", "%Y-%m-%d %H:%M:%S"
)
_handler.setFormatter(_formatter)
logger.addHandler(_handler)
logger.setLevel(logging.INFO)
# # # # # CONFIGURED LOGGER # # # # #


def _int_cast(v: tp.Any) -> tp.Any:
    """casts integers to string"""
    if isinstance(v, int):
        return str(v)
    return v


# type hint for casting integers to string
# this is useful for subject field which can be automatically converted from
# str to int by pandas
StrCast = tp.Annotated[str, pydantic.BeforeValidator(_int_cast)]


def _coerce_spec(v: tp.Any) -> str:
    if isinstance(v, str):
        return v
    if isinstance(v, dict):
        for key, val in v.items():
            if not isinstance(key, str) or not isinstance(val, str):
                raise TypeError(
                    f"spec keys and values must be strings, got {key!r}: {val!r}"
                )
        return "&".join(f"{k}={v[k]}" for k in sorted(v))
    raise TypeError(f"spec must be a dict[str, str], got {type(v).__name__}")


FmriSpec = tp.Annotated[str, pydantic.BeforeValidator(_coerce_spec)]
"""Variant parameters encoded as sorted ``key=value&...`` string.

Accepts a ``dict[str, str]`` on input and auto-encodes it::

    spec={"registration": "msmall", "resolution": "1.6mm"}
    # stored as "registration=msmall&resolution=1.6mm"
"""


def validate_query(query: str) -> str:
    """Ensure a pandas query string is a valid Python expression.

    Examples
    --------
    >>> validate_query("index >= 0")
    'index >= 0'
    >>> validate_query("type in ['Word', 'Audio']")
    "type in ['Word', 'Audio']"

    Raises
    ------
    ValueError
        If the query is not valid Python expression syntax, or if it
        contains ``@`` references to local variables (fragile when the
        query is defined far from its execution site).
    """
    try:
        ast.parse(query, mode="eval")
    except SyntaxError as exc:
        raise ValueError(f"Query '{query}' is not valid: {exc}") from exc
    if "@" in query:
        raise ValueError(
            f"Query '{query}' uses '@' to reference local variables. "
            "This is not supported because the query may be evaluated in a "
            "different scope from where it was defined."
        )
    return query


Query = tp.Annotated[str, pydantic.AfterValidator(validate_query)]
"""Validated pandas query string.

Use as a Pydantic field type to automatically check syntax on construction::

    class MyModel(pydantic.BaseModel):
        query: Query               # required valid query
        query: Query | None = None # optional query, None means "no filter"
"""

CACHE_FOLDER = Path.home() / ".cache/neuralset/"
CACHE_FOLDER.mkdir(parents=True, exist_ok=True)


class NamedModel(exca.helpers.DiscriminatedModel, discriminator_key="name"):
    @property
    def name(self) -> str:  # for compatibility
        return self.__class__.__name__


[docs] class BaseModel(pydantic.BaseModel): """Base pydantic model with extra=forbid and nicer print""" model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid") def __repr__(self) -> str: data = self.model_dump() return f"{self.__class__.__name__}(**\n{pprint.pformat(data, indent=2)}\n)"
class _Module(BaseModel): requirements: tp.ClassVar[tuple[str, ...]] = () @classmethod def _check_requirements(cls) -> None: """Verify that all packages listed in ``requirements`` are importable. Raises ``ModuleNotFoundError`` with a copy-pasteable ``pip install`` command when one or more packages are missing. """ # pip name → importable name, for packages where they differ import_names = { "pillow": "PIL", "scikit-image": "skimage", "opencv-python": "cv2", "datalad-installer": "datalad_installer", "openneuro-py": "openneuro", "sonar-space": "sonar", } missing = [] for req in cls.requirements: name = req.split(">")[0].split("<")[0].split("=")[0].split("[")[0] spec_name = import_names.get(name, name.replace("-", "_")) if importlib.util.find_spec(spec_name) is None: missing.append(name) if missing: raise ModuleNotFoundError( f"{cls.__name__} requires packages that are not installed.\n" f" pip install {' '.join(missing)}" ) @classmethod def _exclude_from_cls_uid(cls) -> list[str]: return [] @tp.final # make sure nobody gets it wrong and override it def __post_init__(self) -> None: """This should not exist in subclasses, as we use pydantic's model_post_init""" @classmethod def __init_subclass__(cls) -> None: super().__init_subclass__() # get requirements from superclasses as well reqs = tuple(x.strip() for x in cls.requirements) for base in cls.__bases__: breqs = getattr(base, "requirements", ()) if breqs is not cls.requirements: reqs = breqs + reqs cls.requirements = reqs # check for _exclude_from_cls_uid override as attribute as it gets ignored exc_tag = "_exclude_from_cls_uid" if exc_tag in getattr(cls, "__private_attributes__", {}): msg = f"Class {cls.__name__!r} cannot have a private attr {exc_tag!r} " msg += f"use a method `def {exc_tag}(cls) -> list[str]` instead." msg += f"(defined in module: {cls.__module__})" raise TypeError(msg) @classmethod def _can_be_instantiated(cls) -> bool: return not cls.__name__.startswith(("Base", "_"))
[docs] class Frequency(float): """Sampling rate in Hz, with helpers for second/sample conversion. A ``float`` subclass that provides ``to_ind`` (seconds → sample index) and ``to_sec`` (sample index → seconds). Examples -------- >>> freq = Frequency(100.0) >>> freq.to_ind(0.5) # 0.5 s at 100 Hz → sample 50 50 >>> freq.to_sec(50) # sample 50 at 100 Hz → 0.5 s 0.5 .. admonition:: Design rationale — ``to_ind`` uses ``round()`` ``to_ind`` uses a single rule — ``round(seconds * freq)`` — for both start times and durations, rather than mixing floor/ceil depending on context. This minimizes worst-case alignment error to ±0.5 samples and keeps the conversion trivially predictable. """ @tp.overload def to_ind(self, seconds: float) -> int: ... @tp.overload # noqa def to_ind(self, seconds: np.ndarray) -> np.ndarray: # noqa ...
[docs] def to_ind(self, seconds: tp.Any) -> tp.Any: # noqa """Convert time in seconds to a sample index. Uses ``round(seconds * frequency)`` to produce a deterministic sample count for any given duration. """ if isinstance(seconds, np.ndarray): return np.round(seconds * self).astype(int) return int(round(seconds * self))
@tp.overload def to_sec(self, index: int) -> float: ... @tp.overload # noqa def to_sec(self, index: np.ndarray) -> np.ndarray: # noqa ...
[docs] def to_sec(self, index: tp.Any) -> tp.Any: # noqa """Converts a sample index to a time in seconds""" return index / self
@staticmethod def _yaml_representer(dumper, data): "Represents Frequency instances as floats in yamls" return dumper.represent_scalar("tag:yaml.org,2002:float", str(float(data)))
_UNSET_START = float("inf") _TA = tp.TypeVar("_TA", bound="TimedArray")
[docs] @DumpContext.register class TimedArray: """Numpy array annotated with time metadata. Carries ``frequency``, ``start``, ``duration``, and an optional ``header`` dict for domain-specific attributes (channel names, electrode positions, space info). Time is always the last dimension (when ``frequency > 0``). .. admonition:: Design rationale Attaching ``frequency`` and ``start`` to the array lets extractors handle time slicing uniformly — whether the data is a time series (``frequency > 0``, slicing by sample indices) or a single static representation (``frequency == 0``, no time dimension). Slicing and ``+=`` handle time alignment automatically (resampling and shifting to a common grid). Data can be backed by memmap for fast access without loading full arrays into memory. """ def __init__( self, *, # forbid positional frequency: float, start: float, data: np.ndarray | None = None, duration: float | None = None, aggregation: tp.Literal["sum", "mean"] = "sum", header: dict[str, tp.Any] | None = None, ) -> None: """ Parameters ---------- frequency: float Sampling frequency of the data. If >0, the last dimension of the data is time; if 0 the data has no time dimension. start: float Start time of the data in seconds. data: optional array If provided, the data with time as last dimension (when frequency > 0). duration: optional float Duration of the data. If ``data`` is also provided and ``frequency > 0``, the last dimension is checked for consistency. If ``data`` is not provided, shape is inferred from the first data added. aggregation: "sum" or "mean" Aggregation mode on the time domain when adding to the timed array. header: optional dict Domain-specific attributes (channel names, electrode positions, space info, etc.) persisted alongside the data. """ self.frequency = Frequency(frequency) self.start = start self.aggregation = aggregation self.header = header exp_size = 0 if duration is not None and duration < 0: raise ValueError(f"duration should be None or >=0, got {duration}") if data is None: if duration is None: raise ValueError("Missing data or duration") # post-poned initialization of data through __iadd__ # initialize with data.size == 0 if not frequency: data = np.zeros((0,)) else: exp_size = max(1, self.frequency.to_ind(duration)) data = np.zeros((0, exp_size)) self.data = data if frequency and duration is not None: exp_size = 0 if not duration else max(1, self.frequency.to_ind(duration)) if duration and not self.data.shape[-1]: msg = "Last dimension is empty but frequency and duration are not null " msg += f"(shape={self.data.shape})" raise ValueError(msg) if abs(data.shape[-1] - exp_size) > 1: msg = f"Data has incorrect (last) dimension {data.shape} for duration " msg += f"{duration} and frequency {frequency} (expected {exp_size})" raise ValueError(msg) if frequency: self.duration = self.frequency.to_sec(data.shape[-1]) elif duration is None: raise ValueError(f"duration must be provided if {frequency=}") else: self.duration = duration # averaging self._overlapping_data_count: None | np.ndarray = None if aggregation == "mean": num = self.data.shape[-1] if self.frequency else 1 self._overlapping_data_count = np.zeros(num, dtype=int) elif aggregation != "sum": raise ValueError(f"Unknown {aggregation=}") def __repr__(self) -> str: # Show shape/dtype only — stringifying data triggers numpy arrayprint, # which reads the underlying memmap and is catastrophic on network storage. cls = self.__class__.__name__ fields = "frequency,start,duration,aggregation".split(",") string = ",".join(f"{f}={getattr(self, f)}" for f in fields) return f"{cls}({string},data=<{self.data.shape} {self.data.dtype}>)" def __iadd__(self, other: "TimedArray") -> "TimedArray": if other.frequency and self.frequency != other.frequency: diff = abs(self.frequency - other.frequency) if diff * max(self.duration, other.duration) >= 0.5: # half sample diff msg = f"Cannot add with different (non-0) frequencies ({other.frequency} and {self.frequency})" raise ValueError(msg) if not self.data.size: # post-poned initialization of data, recover shape from other.data last = -1 if other.frequency else None shape = other.data.shape[:last] if self.frequency: shape += (self.data.shape[-1],) self.data = np.zeros(shape, dtype=other.data.dtype) if self.frequency: slices = [ sa1._overlap_slice(sa2.start, sa2.duration) for sa1, sa2 in [(self, other), (other, self)] ] if slices[0] is None or slices[1] is None: return self # no overlap # slices self_slice = slices[0][-1] other_slice = slices[1][-1] else: if self._overlap_slice(other.start, other.duration) is None: return self # no overlap self_slice = None other_slice = None # materialize ContiguousMemmap (file-IO proxy) before arithmetic self.data = np.asarray(self.data) other_data = np.asarray(other.data[..., other_slice]) if self._overlapping_data_count is None: # sum self.data[..., self_slice] += other_data else: # average counts = self._overlapping_data_count[..., self_slice] upd = counts / (1.0 + counts) self.data[..., self_slice] *= upd self.data[..., self_slice] += (1 - upd) * other_data counts += 1 return self def _overlap_slice( self, start: float, duration: float ) -> tuple[float, float, slice | None] | None: if self.start == _UNSET_START or start == _UNSET_START: raise RuntimeError( "Cannot compute overlap on a TimedArray with unset start time. " "Call with_start() first." ) if duration < 0: raise ValueError(f"duration should be >=0, got {duration=}") overlap_start = max(start, self.start) overlap_stop = min(start + duration, self.start + self.duration) if overlap_stop < overlap_start: return None # no overlap if overlap_stop == overlap_start and self.duration and duration: return None # 2 timed arrays with durations with one starting when the other ends if not self.frequency: return overlap_start, overlap_stop - overlap_start, None if not self.duration: return None # frequency but no duration -> empty start_ind = self.frequency.to_ind(overlap_start - self.start) duration_ind = self.frequency.to_ind(overlap_stop - overlap_start) # # # right edge border case # # # if duration_ind <= 0: # faster than max duration_ind = 1 # then make sure we move the start according to the number of selected samples tps = self.data.shape[-1] if start_ind > tps - duration_ind: start_ind = tps - duration_ind if start_ind < 0: raise RuntimeError(f"Fail for {start=} {duration=} on {self}") start = self.frequency.to_sec(start_ind) + self.start duration = self.frequency.to_sec(duration_ind) # # # build # # # out = start, duration, slice(start_ind, start_ind + duration_ind) return out
[docs] def with_start(self: _TA, start: float) -> _TA: """Return a lightweight copy sharing the data array with a new start time.""" cls = type(self) return cls( data=self.data, frequency=self.frequency, start=start, header=self.header, )
[docs] def overlap(self: _TA, start: float, duration: float) -> _TA: """Returns the sub TimedArray overlapping with the provided start and duration In case of lack of overlap, a timed array with 0 duration and empty data on the time dimension will be returned. """ if not self.frequency: msg = "Cannot call overlap with no time dimension (TimedArray.frequency = 0)" raise RuntimeError(msg) out = self._overlap_slice(start, duration) if out is not None: ostart, oduration, sl = out else: ostart, oduration, sl = min(start, self.start), 0, slice(0, 0) cls = type(self) return cls( frequency=self.frequency, start=ostart, duration=oduration, data=self.data[..., sl], header=self.header, )
# -- DumpContext serialization protocol -- def __dump_info__(self, ctx: tp.Any) -> dict[str, tp.Any]: # Time-first on disk: time-slice reads become contiguous (~9s → sequential on NFS) data = np.moveaxis(self.data, -1, 0) if self.frequency else self.data info: dict[str, tp.Any] = { "data": ctx.dump(data), "frequency": float(self.frequency), "start": self.start, "duration": self.duration, } if self.header: info["header"] = ctx.dump(self.header) return info @classmethod def __load_from_info__(cls, ctx: tp.Any, **content: tp.Any) -> "TimedArray": # file-IO proxy: reads into freeable heap buffers instead of faulting memmap pages ctx.options.replace.setdefault("MemmapArray", "ContiguousMemmapArray") data = ctx.load(content.pop("data")) if content.get("frequency"): data = np.moveaxis(data, 0, -1) # disk is time-first, memory is time-last if "header" in content: content["header"] = ctx.load(content["header"]) return cls(data=data, **content)
yaml.representer.SafeRepresenter.add_representer(Frequency, Frequency._yaml_representer)
[docs] class Step(exca.steps.Step, _Module, discriminator_key="name"): """Base class for composable pipeline nodes. A ``Step`` represents a single operation in a data processing pipeline, such as loading a study (:class:`~neuralset.events.Study`) or transforming events (:class:`~neuralset.events.EventsTransform`). Steps can be executed individually via ``.run()`` or grouped into a :class:`Chain`. Inherits caching and execution infrastructure from ``exca.steps.Step`` (configured via the ``infra`` parameter). Parameters ---------- infra : dict or exca.Infra, optional A pydantic config defining caching and execution infrastructure. It supports options like caching to disk (e.g. using the ``Cached`` backend) and executing remotely on a cluster (e.g. via ``Slurm``). For example: ``{"backend": "Cached", "folder": "~/.cache/neuralset"}``. """ # Chain.model_post_init reads this from the last step; # model_post_init also applies it to the step's own infra. _DEFAULT_CACHE_TYPE: tp.ClassVar[str | None] = None def model_post_init(self, __context: tp.Any) -> None: super().model_post_init(__context) if ( self._DEFAULT_CACHE_TYPE is not None and self.infra is not None and self.infra.cache_type is None ): self.infra.cache_type = self._DEFAULT_CACHE_TYPE @pydantic.model_validator(mode="wrap") @classmethod def _discover_study( cls, value: tp.Any, handler: pydantic.ValidatorFunctionWrapHandler, ) -> "Step": """Trigger lazy study-package scanning for unknown discriminator names.""" if isinstance(value, dict): name = value.get(cls._exca_discriminator_key) if name and name not in cls._get_discriminated_subclasses(): from .events import study as _study _study._resolve_study(name) return handler(value)
[docs] class Chain(exca.steps.Chain, Step): """A sequence of processing steps executed in order. A ``Chain`` groups multiple ``Step`` objects (such as :class:`~neuralset.events.Study` and :class:`~neuralset.events.EventsTransform`) into a single cohesive pipeline. Because a ``Chain`` is itself a ``Step``, it can be nested inside other chains or used anywhere a ``Step`` is expected. When you call ``.run()``, it passes the output of each step as the input to the next. Parameters ---------- steps : list of dict or dict of str to dict The ordered sequence of steps to execute. Since this is a pydantic config, it is strongly recommended to pass a list of dictionaries rather than instantiated objects (these dictionaries are coerced automatically). If a dict of dicts is provided, the keys act as step names. infra : dict or exca.Infra, optional A pydantic config for caching and execution infrastructure inherited from ``Step``. If provided, it determines how the final output of the chain is cached (e.g. using the ``Cached`` backend) or executed remotely (e.g. via ``Slurm``). Examples -------- .. code-block:: python chain = ns.Chain(steps=[ {"name": "MyStudy", "path": "/data", "infra": {"backend": "Cached", "folder": "/cache"}}, {"name": "QueryEvents", "query": "timeline_index < 5"}, ]) events = chain.run() """ steps: list[Step] | collections.OrderedDict[str, Step] # type: ignore def model_post_init(self, __context: tp.Any) -> None: super().model_post_init(__context) # Chain output matches last step: use its cache format (instance > class default). if self.infra is not None and self.infra.cache_type is None: seq = ( list(self.steps.values()) if isinstance(self.steps, dict) else self.steps ) if seq: last = seq[-1] if last.infra is not None and last.infra.cache_type is not None: self.infra.cache_type = last.infra.cache_type else: self.infra.cache_type = last._DEFAULT_CACHE_TYPE