# Source code for neuralfetch.utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Utilities for neuralfetch study development."""

import ast
import inspect
import logging
import os
import subprocess
import sys
import tempfile
import typing as tp
from pathlib import Path

import pandas as pd
from tqdm import tqdm

import neuralset as ns
import neuralset.events as ev
from neuralfetch.download import download_file, extract_zip
from neuralset.events import study as base

logger = logging.getLogger(__name__)


[docs] def root_study_folder(name: str | None = None, test_folder: Path | None = None) -> Path: """Return the root folder where study data is stored. Example ------- >>> folder = neuralfetch.utils.root_study_folder() >>> study = ns.Study(name="Allen2022Massive", path=folder) Built-in test/sample studies use ``ns.CACHE_FOLDER`` (or *test_folder*). All others require ``NEURALSET_STUDY_FOLDER`` env var. """ if name is not None: if name.startswith(("Mne2013Sample", "Fake2025Meg", "Dummy")): return ns.CACHE_FOLDER if name.startswith(("Test", "Fake")): return test_folder if test_folder is not None else ns.CACHE_FOLDER env = os.environ.get("NEURALSET_STUDY_FOLDER") if env is None: raise RuntimeError( "NEURALSET_STUDY_FOLDER env var is not set.\n" "Export it to the root folder containing your study data, e.g.:\n" " export NEURALSET_STUDY_FOLDER=/path/to/root/studies/folder" ) return Path(env)
def compute_study_info(name: str, folder: str | Path) -> dict[str, tp.Any]:
    """Load study *name* from *folder* and return a dict of actual ``StudyInfo`` values.

    Always computes ``num_timelines``, ``num_subjects``, ``num_events_in_query``
    and ``event_types_in_query``. If the queried events contain an Fmri/MneRaw
    event, the first one is read to add ``data_shape`` and ``frequency``
    (plus ``fmri_spaces`` for Fmri events); otherwise these keys are omitted.
    Errors raised while reading that event are not caught here.

    Parameters
    ----------
    name : str
        Name of the study to load.
    folder : str | Path
        Folder passed as ``path`` to ``ns.Study``.

    Returns
    -------
    dict
        Mapping of ``StudyInfo`` field names to the values computed from the data.
    """
    folder = Path(folder)
    default_query = "timeline_index < 1"
    study = ns.Study(name=name, path=folder, query=default_query)
    cls = type(study)
    info = cls._info
    # Prefer the query declared on the class' existing _info, if any.
    query = info.query if info is not None else default_query
    if query != default_query:
        study = ns.Study(name=name, path=folder, query=query)
    study.infra_timelines.cluster = None  # avoid process pool startup
    cls._info = None  # bypass num_timelines check during loading
    try:
        summary = study.study_summary(apply_query=False)
        events = study.run()
    finally:
        # Always restore the class attribute, even if loading failed.
        cls._info = info
    actual: dict[str, tp.Any] = dict(
        num_timelines=len(summary),
        num_subjects=summary.subject.nunique(),
        num_events_in_query=len(events),
        event_types_in_query=set(events["type"].unique()),
    )
    # Read first Fmri/MneRaw event for data_shape / frequency.
    types = ev.etypes.EventTypesHelper(["Fmri", "MneRaw"]).names
    matching = events.loc[events.type.isin(types)]
    if matching.empty:
        return actual
    event = ev.Event.from_dict(matching.iloc[0])
    data = event.read()  # type: ignore
    if isinstance(event, ev.etypes.Fmri):
        actual["data_shape"] = data.shape
        fmri_types = ev.etypes.EventTypesHelper(["Fmri"]).names
        actual["fmri_spaces"] = set(
            matching.loc[matching.type.isin(fmri_types), "space"].unique()
        )
    elif isinstance(event, ev.etypes.MneRaw):
        pick_map: dict[type, str | tuple[str, ...]] = {
            ev.etypes.Eeg: "eeg",
            ev.etypes.Emg: "emg",
            ev.etypes.Fnirs: "fnirs",
            ev.etypes.Ieeg: ("seeg", "ecog"),
            ev.etypes.Meg: "meg",
        }
        # Walk the MRO so that subclasses of the mapped event types resolve to
        # their closest mapped ancestor instead of failing an exact-type lookup.
        picks = next(
            (pick_map[klass] for klass in type(event).__mro__ if klass in pick_map),
            None,
        )
        if picks is not None:
            data.pick(picks)
        actual["data_shape"] = (len(data.ch_names), int(data.n_times))
    actual["frequency"] = event.frequency  # type: ignore[attr-defined]
    return actual
# --------------------------------------------------------------------------- # Source-file rewriting # --------------------------------------------------------------------------- def _find_info_lines(source: str, class_name: str) -> tuple[int, int]: """Return 1-indexed (start, end) line range of the ``_info`` assignment. Handles both annotated (``_info: ... = ...``) and plain (``_info = ...``) assignments. If ``_info`` is absent, returns an empty range before the first method so a splice inserts a new line there. """ tree = ast.parse(source) for node in ast.walk(tree): if not isinstance(node, ast.ClassDef) or node.name != class_name: continue fallback = node.body[-1].end_lineno or node.body[-1].lineno for item in node.body: if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): first_line = ( item.decorator_list[0].lineno if item.decorator_list else item.lineno ) fallback = first_line - 1 break target_name = None if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name): target_name = item.target.id elif isinstance(item, ast.Assign): for t in item.targets: if isinstance(t, ast.Name) and t.id == "_info": target_name = t.id if target_name == "_info": assert item.end_lineno is not None return item.lineno, item.end_lineno return fallback + 1, fallback # empty range: insert at fallback raise ValueError(f"class {class_name} not found") def _repr_val(val: tp.Any) -> str: """Deterministic repr: sorted sets, floats rounded to 3 decimals.""" if isinstance(val, set): return "{" + ", ".join(repr(x) for x in sorted(val)) + "}" if isinstance(val, float): return repr(round(val, 3)) return repr(val) def format_study_info(actual: dict[str, tp.Any]) -> str: """Return a formatted ``StudyInfo(...)`` string from computed values.""" parts = [ f"{f}={_repr_val(actual[f])}" for f in base.StudyInfo.model_fields if f != "query" and f in actual ] code = f"StudyInfo({', '.join(parts)})" result = subprocess.run( [ sys.executable, "-m", "ruff", "format", 
"--line-length=90", "--stdin-filename=_.py", ], input=code, capture_output=True, text=True, check=True, ) return result.stdout.strip()
def update_source_info(name: str, folder: str | Path | None = None) -> dict[str, tp.Any]:
    """Compute actual ``StudyInfo`` values, rewrite the source file, and run ``ruff format``.

    If *folder* is ``None``, uses the default study folder (or cache folder for
    test/fake studies). The ``_info`` class attribute of the study class is
    replaced in its source file (or inserted if absent), then the file is
    reformatted with ``ruff format``.

    Parameters
    ----------
    name : str
        Name of the study whose source file is rewritten.
    folder : str | Path | None
        Folder holding the study data; defaults to ``root_study_folder(name)``.

    Returns
    -------
    dict
        The computed values dict.

    Raises
    ------
    RuntimeError
        If the source file of the study class cannot be located.

    Usage::

        python -c "from neuralfetch.utils import update_source_info; update_source_info('StudyName')"
    """
    if folder is None:
        folder = root_study_folder(name)
    actual = compute_study_info(name, folder)
    info_str = format_study_info(actual)
    # Rewrite source file.
    cls = type(ns.Study(name=name, path="."))
    source_file = inspect.getsourcefile(cls)
    if source_file is None:
        raise RuntimeError(f"Cannot locate source file for {name}")
    path = Path(source_file)
    source = path.read_text("utf8")
    lines = source.splitlines(keepends=True)
    start, end = _find_info_lines(source, cls.__name__)
    # Reuse the indentation of the line being replaced so the class body stays
    # consistently indented; fall back to four spaces when inserting a new line.
    indent = "    "
    if start <= end:
        existing = lines[start - 1]
        indent = existing[: len(existing) - len(existing.lstrip())]
    new_info = f"{indent}_info: tp.ClassVar[study.StudyInfo] = study.{info_str}\n"
    lines[start - 1 : end] = [new_info]
    # Write back with the same encoding used for reading.
    path.write_text("".join(lines), encoding="utf8")
    subprocess.run([sys.executable, "-m", "ruff", "format", str(path)], check=True)
    logger.info("Updated _info in %s", path)
    return actual
# --------------------------------------------------------------------------- # Study data utilities (moved from neuralset.studies.utils) # --------------------------------------------------------------------------- def add_sentences( events: pd.DataFrame, ratios: tuple[float, float, float] = (0.8, 0.1, 0.1), column_to_group: str = "sequence_id", ) -> pd.DataFrame: """ Add sentence-level information to the events DataFrame based on the sequence_id column. """ assert column_to_group in events.columns assert "Sentence" not in events.type.unique() if "timeline" in events.columns and len(events.timeline.unique()) > 1: # apply to each timeline timelines = [] for _, df in tqdm(events.groupby("timeline"), "Adding sentences"): df = add_sentences(df, ratios, column_to_group) timelines.append(df) return pd.concat(timelines) events["stop"] = events.start + events.duration words = events.query('type=="Word"') assert all(words.start.diff().dropna() >= 0) # type: ignore # Add sentence-level information words = events.query('type=="Word"') sentences = [] for _, sent in words.groupby(column_to_group, sort=False): # Find all events within the sentence duration = sent.stop.max() - sent.start.min() + 1e-8 sentence = " ".join(sent.text) events.loc[sent.index, "sentence"] = sentence # Add Sentence event to_add = sent.iloc[0].to_dict() to_add["type"] = "Sentence" to_add["text"] = " ".join(sent.text) to_add["start"] = sent.start.min() - 1e-8 to_add["duration"] = duration to_add["stop"] = sent.stop.max() sentences.append(to_add) events = pd.concat([events, pd.DataFrame(sentences)], ignore_index=True) events = events.sort_values("start") events = events.reset_index(drop=True) return events def scan_files( path: str | Path, stopping_criterion: tp.Callable[[str], bool] | None = None, ) -> tp.Iterator[str]: """Recursively yield file paths from given directory. Parameters ---------- path : Directory path to scan for files recursively. 
stopping_criterion : Callable that returns True if the current node/leaf should be yielded instead of scanned further. E.g. useful to catch recordings of neuro data format that are saved as folders, e.g. the CTF format (MEG) saves recordings as folders with a .ds extension. Note ---- About 2x faster than a combination of Path.iterdir and Path.rglob for walking through Obeid2016's directory. """ stopping_criterion = stopping_criterion or (lambda x: False) for entry in sorted(os.scandir(path), key=lambda e: e.name): if stopping_criterion(entry.name) or not entry.is_dir(follow_symlinks=False): yield entry.path else: yield from scan_files(entry.path, stopping_criterion=stopping_criterion) def ensure_imagenet22k( custom_path: Path | str | None = None, fail_on_error: bool = False, ) -> Path: """Ensure ImageNet-22k dataset exists, providing download instructions if not. Parameters ---------- custom_path : Path | str | None, optional Custom path where ImageNet-22k is or should be located. If None, reads ``NEURALSET_IMAGENET22K`` env var; raises if unset. fail_on_error : bool, optional If True, raise an exception if ImageNet-22k is not found. If False (default), log a warning and continue. Returns ------- Path Path to the ImageNet-22k directory (may not exist if not yet downloaded). """ if custom_path is None: env = os.environ.get("NEURALSET_IMAGENET22K") if env is None: raise RuntimeError( "ImageNet-22k path not configured.\n" "Please export NEURALSET_IMAGENET22K=/path/to/imagenet-22k " "or pass custom_path= to ensure_imagenet22k()." 
) imagenet_path = Path(env) else: imagenet_path = Path(custom_path).resolve(strict=False) if imagenet_path.exists(): logger.info(f"ImageNet-22k found at {imagenet_path}") return imagenet_path logger.warning(f"ImageNet-22k not found at expected location: {imagenet_path}") logger.info("\n" + "=" * 80) logger.info("ImageNet-22k is required for natural image stimuli.") logger.info("=" * 80) logger.info("\nImageNet-22k (ImageNet-21k) Access Instructions:") logger.info("1. Visit the ImageNet website: https://image-net.org/") logger.info("2. Create an account or log in") logger.info("3. Review and agree to the ImageNet Terms of Access") logger.info("4. Request access to ImageNet-22k (Fall 2011 release or later)") logger.info( "5. Follow their download instructions (typically via direct download or torrent)" ) logger.info("6. Extract the dataset to maintain synset folder structure:") logger.info(" imagenet-22k/") logger.info(" n01440764/ (synset folders)") logger.info(" n01440764_00001.JPEG") logger.info(" ...") logger.info(" ...") logger.info(f"7. Place the dataset at: {imagenet_path}") logger.info("\nDataset Details:") logger.info("- Size: ~1.3 TB compressed, ~1.5 TB uncompressed") logger.info("- Images: ~14 million") logger.info("- Classes: ~21,841 WordNet synsets") logger.info("- Format: JPEG images organized by synset folders") logger.info("\nAlternative:") logger.info("- Check if your institution has a local copy") logger.info("- Contact your system administrator for access") logger.info("=" * 80) if fail_on_error: raise RuntimeError( f"ImageNet-22k required but not found at {imagenet_path}. " "Please obtain access from https://image-net.org/" ) logger.info( "\nContinuing without ImageNet-22k. Study may fail if natural image " "stimuli are required." ) return imagenet_path def download_things_images( study_path: Path, fail_on_error: bool = False, ) -> Path: """Download THINGS-images dataset if not already present. 
THINGS-images contains the full image database (~1,854 object concepts) organized by category folders. The dataset is shared across multiple THINGS-related studies (Hebart2023ThingsFmri, Gifford2022ThingsEeg, Grootswagers2022ThingsEeg2, etc.). Full license terms: https://osf.io/jum2f/files/52wrx Parameters ---------- study_path : Path Path to the study directory. THINGS-images should be in the parent directory as a sibling folder. fail_on_error : bool, optional If True, raise an exception if THINGS-images cannot be obtained. """ things_images_path = (study_path / ".." / "THINGS-images").resolve(strict=False) if things_images_path.exists(): logger.info(f"THINGS-images found at {things_images_path}") return things_images_path logger.warning(f"THINGS-images not found at expected location: {things_images_path}") password = os.environ.get("NEURALFETCH_THINGS_PASSWORD") if password is None: raise RuntimeError( "THINGS-images requires accepting a licence agreement.\n" "Please export NEURALFETCH_THINGS_PASSWORD=<pwd> where the password can be found " "in password_images.txt at https://osf.io/jum2f/files/52wrx" ) zip_url = "https://files.osf.io/v1/resources/jum2f/providers/osfstorage/670d6f7dbce84f4a7bb9371b" print("\nDownloading images_THINGS.zip...") logger.info(f"Downloading from {zip_url}") with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp_file: tmp_path = Path(tmp_file.name) try: download_file(zip_url, tmp_path, show_progress=True) print("Download complete!") extract_path = things_images_path.parent print(f"\nExtracting images_THINGS.zip to {extract_path}...") print("This may take ~1 hour...") logger.info(f"Extracting THINGS-images to {extract_path}") extract_zip(tmp_path, extract_path, password=password, remove_after=True) extracted_folder = extract_path / "object_images" extracted_folder.rename(things_images_path) print(f"\nTHINGS-images successfully installed at {things_images_path}") logger.info(f"THINGS-images installed at {things_images_path}") 
return things_images_path except Exception as e: logger.error(f"Error downloading or extracting THINGS-images: {e}") print(f"\nError: {e}") if tmp_path.exists(): tmp_path.unlink() if fail_on_error: raise RuntimeError(f"Failed to download THINGS-images: {e}") return things_images_path