# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import functools
import importlib.metadata
import logging
import os
import typing as tp
import warnings
from pathlib import Path
import exca
import exca.steps
import pandas as pd
import pydantic
import ujson
from neuralset import base
from . import etypes, utils
# re-export: study implementations use study.SpecialLoader
from .utils import SpecialLoader as SpecialLoader # isort: skip # pylint: disable=W0611
logger = logging.getLogger(__name__)
def _check_folder_path(path: base.PathLike, name: str) -> Path:
"""Check that the parent path exists and create directory"""
path = Path(path)
if not path.parent.exists():
raise RuntimeError(f"Parent folder {path.parent} of {name} must exist first.")
path.mkdir(exist_ok=True)
return path
def _identify_study_subfolder(path: str | Path, name: str) -> Path:
"""If provided with path and study name, if
path / name or path / name.lower() exist then returns the first.
This way path can be generic for all studies provided the actual study
folder is a sub-folder with its name.
"""
path = Path(path)
path = _check_folder_path(path, name=f"{name}.path")
if path.name.lower() != name.lower():
# use the subfolder with capitalized or uncapitalized name if it exists,
# this enables using same folder everywhere
for n in (name, name.lower()):
if (path / n).exists():
logger.debug("Updating study path to %s", path)
return path / n
return path
def _set_dir_permissions(path: Path) -> None:
    """Recursively set 777 permissions on a directory (Unix only).

    Skips files/directories not owned by the current user, since only the
    owner (or root) may chmod. This is expected on shared filesystems where
    another user originally downloaded the study.
    """
    if os.name == "nt":
        logger.info("Skipping permission setting on Windows.")
        return
    current_uid = os.getuid()
    skipped = 0
    for root, _dirs, files in os.walk(path):
        # every directory is visited exactly once as "root", so chmodding
        # root + files covers the whole tree (previously subdirectories were
        # also processed through the parent's dir list, double-chmodding them
        # and double-counting skips)
        for item in [root] + [os.path.join(root, n) for n in files]:
            try:
                if os.stat(item).st_uid != current_uid:
                    skipped += 1
                    continue
                os.chmod(item, 0o777)
            except PermissionError:
                logger.debug("Cannot chmod %s (not owner), skipping.", item)
                skipped += 1
    # single summary line (lazy %-args) instead of two overlapping messages
    logger.info("Permissions set for %s (skipped %d non-owned items).", path, skipped)
def _scan_package_for_studies(base_dir: Path, base_module: str, name: str) -> None:
"""Scan a package directory tree for study modules and import those
that may define a class matching *name*.
When *name* is empty (called from :meth:`Study.catalog`), every non-test
module is imported so that all ``__init_subclass__`` registrations fire.
Matches class definitions (``class Name``), module-level aliases
(``Name =``), and string literals (``"Name"``).
"""
for fp in base_dir.rglob("*.py"):
if fp.name.startswith(("test_", "_")) or "-" in fp.stem:
continue
rel = fp.relative_to(base_dir).with_suffix("")
module = base_module + "." + ".".join(rel.parts)
try:
text = fp.read_text("utf8")
except FileNotFoundError:
continue # editable installs can race with file creation
if not name or name in text:
importlib.import_module(module)
if name and name in STUDIES:
logger.debug("study %r found in %s", name, module)
return
@functools.lru_cache(maxsize=1)
def _get_study_packages() -> tuple[str, ...]:
    """Return deduplicated module paths registered under the
    ``neuralset.studies`` entry-point group. Each value is a dotted
    module path used as the scan root for study discovery.
    """
    # dict keys keep insertion order while dropping duplicates
    unique: dict[str, None] = {}
    for entry_point in importlib.metadata.entry_points(group="neuralset.studies"):
        unique.setdefault(entry_point.value, None)
    result = tuple(unique)
    logger.debug("study packages discovered via entry points: %s", result)
    return result
def _resolve_study(name: str = "") -> tp.Type["Study"] | None:
    """Look up *name* in STUDIES, scanning external packages if needed.

    Returns the class if found, None otherwise.
    Pass empty string to trigger a full scan of all study packages.
    """
    if name and name in STUDIES:
        return STUDIES[name]
    for pkg_name in _get_study_packages():
        if name and name in STUDIES:
            break  # located while scanning an earlier package
        try:
            pkg = importlib.import_module(pkg_name)
        except ImportError:
            continue
        if pkg.__file__ is None:
            continue  # e.g. namespace package: nothing to scan from
        _scan_package_for_studies(Path(pkg.__file__).parent, pkg_name, name)
    found = STUDIES.get(name)
    if name and found is None:
        packages = _get_study_packages()
        if not packages:
            msg = (
                f"Study {name!r} not found. "
                f"Install neuralfetch (pip install neuralfetch) "
                f"or import the module that defines it before use."
            )
        else:
            msg = (
                f"Study {name!r} not found (scanned: {', '.join(packages)}). "
                f"Import the module that defines it before use."
            )
        raise ImportError(msg)
    return found
class StudyInfo(base.BaseModel):
    """Records expected dataset characteristics for testing and validation.

    Provides a baseline for automatic unit tests to verify that a study's
    data loading logic and file parsing behave correctly and completely.

    Attributes
    ----------
    num_timelines : int
        The total number of timelines (e.g. subject sessions) expected.
    num_subjects : int
        The expected number of unique subjects.
    query : str
        A query applied during tests to subsample the data (default: ``"timeline_index < 1"``).
    num_events_in_query : int
        The expected number of events after applying the query.
    event_types_in_query : set of str
        The expected set of event types present in the queried data.
    data_shape : tuple of int
        The expected shape of the primary data arrays.
    frequency : float
        The expected sampling frequency of the data, in Hz.
    fmri_spaces : tuple of str
        (fMRI only) The expected spatial reference spaces for the data.
    """

    # forbid unknown fields so typos in test baselines fail loudly
    model_config = pydantic.ConfigDict(extra="forbid")
    num_timelines: int = 0
    num_subjects: int = 0
    # from query
    query: str = "timeline_index < 1"
    num_events_in_query: int = 0
    event_types_in_query: set[str] = set()
    data_shape: tuple[int, ...] = ()
    frequency: float = 0.0
    # for FMRI only
    fmri_spaces: tuple[str, ...] | set[str] = ()
# Lives here (not in transforms/) to avoid circular import: events → transforms → extractors
class EventsTransform(base.Step):
    """Base class for steps that modify an :term:`events <event>` dataframe.

    Transforms take an input events DataFrame, modify it (e.g. by adding new
    columns, filtering rows, or deriving new events), and return the modified
    DataFrame.

    Subclasses should override ``_run(self, events: pd.DataFrame) -> pd.DataFrame``.

    Examples
    --------
    .. code-block:: python

        class MyFilter(ns.EventsTransform):
            def _run(self, events: pd.DataFrame) -> pd.DataFrame:
                return events[events["type"] == "Audio"]
    """

    # cache backend used by the Step machinery for this transform's output
    _DEFAULT_CACHE_TYPE: tp.ClassVar[str | None] = "ValidatedParquet"

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: tp.Any) -> None:
        # Wrap each subclass's own `_run` so its output is always passed
        # through utils.standardize_events — subclasses cannot forget it.
        # Only wraps when the subclass defines `_run` itself (checked via
        # cls.__dict__) so an inherited, already-wrapped `_run` is not
        # wrapped twice.
        super().__pydantic_init_subclass__(**kwargs)
        if "_run" in cls.__dict__:
            original = cls.__dict__["_run"]

            @functools.wraps(original)
            def _validated_run(self: tp.Any, *args: tp.Any, **kw: tp.Any) -> pd.DataFrame:
                result = original(self, *args, **kw)
                # auto_fill=False: validate/normalize only, do not invent columns
                return utils.standardize_events(result, auto_fill=False)

            cls._run = _validated_run  # type: ignore[assignment]

    def __call__(self, events: pd.DataFrame) -> pd.DataFrame:
        """Standalone usage: delegates directly to _run."""
        return self._run(events)

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        # must be overridden by subclasses (see class docstring example)
        raise NotImplementedError
# STUDY V2
# Latest resolved data path per study class name — written by
# Study.model_post_init, Study.download and Study.__setstate__.
STUDY_PATHS: dict[str, Path] = {}
# Registry of all Study subclasses keyed by class name, filled by
# Study.__init_subclass__.
STUDIES: dict[str, tp.Type["Study"]] = {}
def _dict_to_kv(d: dict[str, tp.Any]) -> str:
"""Format dict as ``key=val,key2=val2`` with sorted keys."""
return ",".join(f"{k}={v}" for k, v in sorted(d.items()))
class Study(base.Step):
    """Interface to an external dataset: loads :term:`events <event>` from raw recordings.

    Subclass ``Study`` to create an interface to a new dataset. Override
    :meth:`iter_timelines` to enumerate :term:`timelines <timeline>` and
    :meth:`_load_timeline_events` to load events for each one.

    Parameters
    ----------
    path : Path
        Root directory for the study data. All studies can use the same
        ``path`` (e.g. ``path="/data"``): each study automatically
        resolves its own subfolder (e.g. ``/data/MyStudy``), so you
        only need to configure one path for your entire data store.
    infra_timelines : MapInfra
        Caching/compute backend for per-timeline event loading. Uses
        multiprocessing by default (``cluster="processpool"``); set
        ``cluster=None`` to disable (slower, but easier to debug).
    query : Query or None
        Optional filter applied after loading (e.g. ``"timeline_index < 5"``).

    Examples
    --------
    .. code-block:: python

        # Direct instantiation:
        study = MyStudy(path="/data/studies")
        # path auto-resolves to /data/studies/MyStudy if that subfolder exists
        events = study.run()

        # By name (auto-imports the study class):
        study = Study(name="MyStudy", path="/data/studies")
        events = study.run()

    .. note::

        Studies reference third-party datasets that are subject to their own
        licenses. You may have other legal obligations or restrictions that
        govern your use of that content. Check each study's :attr:`licence`
        and :attr:`url` attributes for details.
    """

    _DEFAULT_CACHE_TYPE: tp.ClassVar[str | None] = "ValidatedParquet"
    path: Path
    infra_timelines: exca.MapInfra = exca.MapInfra(cluster="processpool")
    query: base.Query | None = None
    # internal
    _cls_string: str = ""  # cache for the class prefix of timeline strings
    # Class level info
    _info: tp.ClassVar[None | StudyInfo] = None  # for easy testing
    aliases: tp.ClassVar[tuple[str, ...]] = ()
    url: tp.ClassVar[str] = ""
    bibtex: tp.ClassVar[str] = ""
    licence: tp.ClassVar[str] = ""
    description: tp.ClassVar[str] = ""

    @classmethod
    def catalog(cls) -> dict[str, tp.Type["Study"]]:
        """All registered Study subclasses, keyed by name.

        Triggers lazy imports so that studies from all installed
        packages (e.g. neuralfetch) are discovered.

        Returns
        -------
        dict[str, type[Study]]
            ``{name: StudySubclass}`` for every registered study.
        """
        _resolve_study()  # empty name -> full scan of all study packages
        out = dict(STUDIES)
        if not out:
            raise ImportError(
                "No studies found. Install a study package (e.g. pip install neuralfetch) "
                "or import your own Study subclass before calling catalog()."
            )
        return out

    @classmethod
    def neuro_types(cls) -> frozenset[str]:
        """Neural recording event types from ``_info``, empty if ``_info`` is not set."""
        if cls._info is None:
            return frozenset()
        names = set(etypes.EventTypesHelper(("MneRaw", "Fmri")).names)
        return frozenset(cls._info.event_types_in_query & names)

    if tp.TYPE_CHECKING:

        def __init__(  # pylint: disable=super-init-not-called
            self, **kwargs: tp.Any
        ) -> None: ...  # override needed: __new__ breaks mypy field detection

    def __new__(cls, /, **kwargs: tp.Any) -> "Study":
        # __new__ because discovery must happen before DiscriminatedModel dispatches
        # (model_validator runs too late — after __new__ already picked the class)
        name = kwargs.get(cls._exca_discriminator_key, "")
        if name:
            _resolve_study(name)
        return super().__new__(cls, **kwargs)  # type: ignore[return-value]

    def iter_timelines(self) -> tp.Iterator[dict[str, tp.Any]]:
        """Yield one parameter dict per timeline (must include a 'subject' key)."""
        raise NotImplementedError
        # # example:
        # for subject in range(20):
        #     for session in range(10):
        #         yield dict(subject=subject, session=session)

    def _load_timeline_events(self, timeline: dict[str, tp.Any]) -> pd.DataFrame:
        """Load the events dataframe for one timeline dict from iter_timelines."""
        raise NotImplementedError
        # # example
        # event = dict(
        #     type="Eeg",
        #     subject=params["subject"],
        #     start=0,
        #     duration=100,
        #     filepath=f"{params['subject']}-{params['session']}.fif",
        # )
        # return pd.DataFrame([event])

    def _download(self) -> None:
        """Download dataset.

        Needs to be overriden by user.
        """
        raise NotImplementedError("Dataset not available to download yet.")

    @tp.final
    def download(self, **kwargs: tp.Any) -> None:
        """Download the dataset into ``self.path`` and sanity-check the result.

        Raises
        ------
        RuntimeError
            if after downloading the path is missing, not a directory, or empty.
        """
        self._check_requirements()
        # Ensure download goes into a subfolder named after the study
        name = self.__class__.__name__
        if self.path.name.lower() != name.lower():
            self.path = self.path / name
            STUDY_PATHS[name] = self.path
            logger.info("Download path updated to %s", self.path)
        self.path.mkdir(parents=True, exist_ok=True)
        self._download(**kwargs)
        if not self.path.exists():
            raise RuntimeError(f"Path does not exist: {self.path}")
        if not self.path.is_dir():
            raise RuntimeError(f"Path is not a directory: {self.path}")
        if not any(self.path.iterdir()):
            raise RuntimeError(f"Directory is empty: {self.path}")
        logger.info("Success: Study downloaded to %s.", self.path)
        _set_dir_permissions(self.path)
        self.clear_cache()

    def _cls_kwargs(self) -> dict[str, tp.Any]:
        """Descriptor for the study instance parametrization"""
        cls_kwargs: tp.Any = self.model_dump(serialize_as_any=True, exclude_defaults=True)
        # Exclude standard fields from class kwargs
        for p in ["infra", "infra_timelines", "path", "name", "query"]:
            cls_kwargs.pop(p, None)
        if cls_kwargs:
            # should the class parameter be part of the timeline? or does
            # it select a subset? the behavior is unclear and should be
            # specified precisely first.
            msg = "Class parameters are not yet supported, bring up your use-case!"
            raise RuntimeError(msg)
        return cls_kwargs

    def _to_timeline_string(self, timeline: dict[str, tp.Any]) -> str:
        """Build the unique '<Class>[:<cls kwargs>]:<timeline kwargs>' identifier."""
        if not self._cls_string:
            self._cls_string = self.__class__.__name__
            cls_kwargs = self._cls_kwargs()
            if cls_kwargs:
                self._cls_string += ":" + _dict_to_kv(cls_kwargs)
        return self._cls_string + ":" + _dict_to_kv(timeline)

    @classmethod
    def _exclude_from_cls_uid(cls) -> list[str]:
        # the data location must not change the cache uid
        return super()._exclude_from_cls_uid() + ["path"]

    def model_post_init(self, log__: tp.Any) -> None:
        """Resolve the study subfolder and propagate infra settings."""
        super().model_post_init(log__)
        if type(self) is Study:
            raise TypeError(
                "Study cannot be instantiated directly — use a subclass, "
                "or pass name= to dispatch: Study(name='MyStudy2024', path=...)"
            )
        name = self.__class__.__name__
        self.path = _identify_study_subfolder(self.path, name)
        STUDY_PATHS[name] = self.path  # record for path lookup
        # Auto-propagate cache folder and mode so users only need to set it once
        if self.infra is not None:
            if self.infra.folder is not None and self.infra_timelines.folder is None:
                self.infra_timelines.folder = self.infra.folder
            if "mode" in self.infra.model_fields_set:
                if "mode" not in self.infra_timelines.model_fields_set:
                    mode = self.infra.mode
                    # "retry" must not cascade: recompute once, then reuse caches
                    self.infra_timelines.mode = "cached" if mode == "retry" else mode

    def __init_subclass__(cls, **kwargs: tp.Any) -> None:
        name = cls.__name__
        super().__init_subclass__(**kwargs)
        # validate before registering so an invalid class is never left
        # in the STUDIES registry after the raise
        if hasattr(cls, "version"):
            msg = (
                f"{name}.version must be specified through {name}.infra_timelines.version"
            )
            raise RuntimeError(msg)
        if not name.startswith("_"):
            existing = STUDIES.get(name)
            if existing is not None and existing is not cls:
                raise RuntimeError(
                    f"Study name collision: {name!r} is already registered to "
                    f"{existing.__module__}.{existing.__qualname__}, "
                    f"cannot re-register to {cls.__module__}.{cls.__qualname__}"
                )
            STUDIES[name] = cls

    def __setstate__(self, state: dict[str, tp.Any]) -> None:
        super().__setstate__(state)
        # reregister study path (e.g. after unpickling in a worker process)
        STUDY_PATHS[self.__class__.__name__] = self.path

    @infra_timelines.apply(
        item_uid=lambda x: ujson.dumps(x, sort_keys=True),
        exclude_from_cache_uid=("query",),
        cache_type="ValidatedParquet",  # preserves str dtypes (CSV doesn't)
    )
    def _load_timelines(
        self, timelines: tp.Iterable[dict[str, tp.Any]]
    ) -> tp.Iterator[pd.DataFrame]:
        """Loads raw timelines and cache them"""
        cls_name = self.__class__.__name__
        for timeline in timelines:
            if "subject" not in timeline:
                raise RuntimeError("timeline dict must contain 'subject' key")
            out = self._load_timeline_events(timeline)
            # Core columns are always overwritten
            out.loc[:, "subject"] = f"{cls_name}/{timeline['subject']}"
            out.loc[:, "timeline"] = self._to_timeline_string(timeline)
            out.loc[:, "study"] = cls_name
            # Extra columns from timeline dict: conflict-check then set
            extra: dict[str, str] = {}
            for key, value in timeline.items():
                if key not in ("subject", "path", "timeline"):
                    extra[key] = str(value)
            for entity in utils.BIDS_ENTITIES:
                if entity != "subject":
                    extra.setdefault(entity, utils.BIDS_ENTITY_DEFAULT)
            for col, value in extra.items():
                if col not in out.columns:
                    out.loc[:, col] = value
                elif (out[col].astype(str) == value).all():
                    # redundant but consistent column: tolerated for now
                    warnings.warn(
                        f"Column '{col}' from timeline dict already exists "
                        f"in the events dataframe with matching values. "
                        f"Remove it from _load_timeline_events to avoid "
                        f"future errors.",
                        FutureWarning,
                    )
                elif (
                    col in utils.BIDS_ENTITIES
                    and (out[col].astype(str) == utils.BIDS_ENTITY_DEFAULT).all()
                ):
                    # default placeholder may be overridden by the timeline value
                    out.loc[:, col] = value
                else:
                    raise ValueError(
                        f"Column '{col}' from timeline dict already exists "
                        f"in the events dataframe with different values."
                    )
            out = utils.standardize_events(out)
            yield out

    def _run(self, events: pd.DataFrame | None = None) -> pd.DataFrame:
        """Load study data and optionally concatenate with existing events."""
        if events is not None:  # special case, concatenate
            df = self._run()
            return pd.concat([events, df], ignore_index=True).reset_index(drop=True)
        name = self.__class__.__name__
        timelines = []
        # iterate 1 by 1 to provide explicit msg to different kinds of bugs
        try:
            for tl in self.iter_timelines():
                timelines.append(tl)
        except Exception as e:
            # for a bug on the first timeline with no folder, raise with more info,
            if not timelines and not self.path.exists():
                msg = f"For {name}, you may need to run study.download() first "
                msg += f"as {self.path} does not exist."
                raise RuntimeError(msg) from e
            raise
        if not timelines:
            raise RuntimeError(f"No timeline found for {name} in {self.path}")
        # verify number of timelines in info
        if self._info is not None:
            tls = self._info.num_timelines
            if tls != len(timelines):
                msg = f"Dataset {name} is corrupted, expected {tls} timelines "
                msg += f"but found {len(timelines)} (check/redownload dataset "
                msg += f"folder {self.path} or update study class)"
                raise RuntimeError(msg)
        # filter through summary
        if self.query is not None:
            summ = self.study_summary(apply_query=True)
            timelines = [timelines[k] for k in summ.index]
            if not timelines:
                msg = f"Did not find any timeline for {name}.query={self.query} "
                msg += f"with summary:\n{self.study_summary(apply_query=False)}"
                raise RuntimeError(msg)
        # load and concatenate
        timelines_events = list(self._load_timelines(timelines))
        out = pd.concat(timelines_events).reset_index(drop=True)
        return utils.standardize_events(out, auto_fill=False)

    def build(self) -> pd.DataFrame:
        """alias for run"""
        return self.run()

    def study_summary(self, apply_query: bool = True) -> pd.DataFrame:
        """Returns a dataframe with 1 row per timeline and study attributes as columns.

        :code:`query` parameter is used on this dataframe for subselection

        Parameters
        ----------
        apply_query: bool
            if False returns the full summary, otherwise filter it
            according to the query

        Virtual query columns
        ---------------------
        The following columns are **not** present in the returned DataFrame
        but are auto-generated by :func:`~neuralset.events.transforms.query_with_index`
        when ``apply_query=True`` and the ``query`` string references them:

        :code:`subject_index`: int
            the index of the subject in the study
        :code:`timeline_index`: int
            the index of the timeline in the study (equivalent to "index")
        :code:`subject_timeline_index`: int
            the index of the timeline among a subject's timelines in the study
            (used for querying at most :code:`n` timelines per subjects)
        """
        name = self.__class__.__name__
        tls = list(self.iter_timelines())
        out = pd.DataFrame(tls)
        if out.empty:
            raise RuntimeError(f"No timeline found for {self!r}")
        out["subject"] = out["subject"].apply(lambda x: f"{name}/{x}")
        out.loc[:, "timeline"] = [self._to_timeline_string(tl) for tl in tls]
        out = out.sort_values("subject", kind="stable")
        if apply_query and self.query is not None:
            out = utils.query_with_index(out, self.query)
        return out