Source code for neuralset.events.utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Event utilities: extraction, validation, and BIDS entity constants."""

import inspect
import logging
import re
import typing as tp
import warnings
from collections import defaultdict
from pathlib import Path

import pandas as pd
import ujson
from exca.cachedict import DumpContext
from exca.cachedict.handlers import ParquetPandasDataFrame

from neuralset import utils

from .etypes import Event, EventTypesHelper

logger = logging.getLogger(__name__)

# Everything accepted as an event-type filter: a type name, a sequence of
# names, an Event subclass, or an already-built EventTypesHelper.
TypesParam = str | tp.Sequence[str] | tp.Type[Event] | EventTypesHelper

# Standard BIDS entity columns enforced on every events DataFrame. They are not
# fields of the Event classes; they are managed at the DataFrame level.
BIDS_ENTITIES: tuple[str, ...] = ("subject", "session", "task", "run")
# Filler for an absent or missing BIDS entity value.
BIDS_ENTITY_DEFAULT: str = ""


def extract_events(obj: tp.Any, types: TypesParam | None = None) -> list[Event]:
    """Returns a list of neuralset events extracted from the input parameter.

    Parameters
    ----------
    obj: dataframe, dict/series, event, segment, or list/tuple of events/segments
        the object to extract the list of events from, generally a dataframe or a list of
        segments
    types: optional str/list of str/ Event type/EventTypesHelper
        filters only the provided type(s)

    Returns
    -------
    list of Event
        the list of events extracted from the input object, possibly filtered by the
        provided type. An unfiltered list of events is returned as-is; a tuple
        is converted to a list. Segments are flattened into their events (and
        triggers), de-duplicated by object identity.

    Raises
    ------
    NotImplementedError
        if ``obj`` (or the items of a list/tuple) has an unsupported type.

    Notes
    -----
    Dataframe rows whose ``type`` is not a registered Event class are dropped
    with a warning. Events built from a dataframe get its index in ``_index``.
    """
    from neuralset import segments  # lazy to avoid circular imports

    helper: EventTypesHelper | None = None
    if isinstance(types, EventTypesHelper):
        helper = types
    elif types is not None:
        helper = EventTypesHelper(types)
    # fast track for lists as it's the most common
    if isinstance(obj, (list, tuple)):
        if not obj:
            return []
        if isinstance(obj[0], Event):
            if helper is not None:
                return [e for e in obj if isinstance(e, helper.classes)]
            # honour the ``list`` return type for tuple inputs as well
            return obj if isinstance(obj, list) else list(obj)
    if isinstance(obj, pd.DataFrame):
        if helper is not None:
            obj = obj.loc[obj.type.isin(helper.names), :]
        unknown = set(obj.type) - set(Event._CLASSES)
        if unknown:
            logger.warning("Ignoring unknown event types: %s", unknown)
            obj = obj.loc[~obj.type.isin(unknown), :]
        # skip itertuple if only one/two event :) (pandas is slooow)
        num = len(obj)
        iterable = (obj.iloc[k, :] for k in range(num)) if num <= 2 else obj.itertuples()
        out = [Event.from_dict(r) for r in iterable]
        for i, e in zip(obj.index, out):
            e._index = i  # noqa
        return out
    if isinstance(obj, Event):
        obj = [obj]
    elif isinstance(obj, (dict, pd.Series)):
        obj = [Event.from_dict(obj)]
    if isinstance(obj, segments.Segment):
        obj = [obj]
    if not isinstance(obj, (list, tuple)):
        raise NotImplementedError(f"Conversion of {type(obj)} is not supported")
    # obj is non-empty here: empty sequences returned above, others were wrapped
    if isinstance(obj[0], segments.Segment):
        # keyed by id() so an event shared by overlapping segments appears once
        event_dict = {}
        for segment in obj:
            event_dict.update({id(e): e for e in segment.ns_events})
            if segment.trigger is not None:
                event_dict[id(segment.trigger)] = segment.trigger
        obj = list(event_dict.values())
        if not obj:  # segments without any event: avoid indexing an empty list
            return []
    if not isinstance(obj[0], Event):
        raise NotImplementedError(f"Unexpected list of {type(obj[0])} is not supported")
    return extract_events(obj, types=helper)


def expand_bids_fmri(
    pattern: str, preproc: str, **common: tp.Any
) -> list[dict[str, tp.Any]]:
    """Resolve a BIDS glob pattern into one Fmri event dict per space.

    Files matching ``pattern + ".nii.gz"`` or ``pattern + ".gii"`` are parsed
    for BIDS entities; files without a ``space`` entity are ignored. Within a
    space, a file with a ``hemi`` entity is a surface hemisphere ("L" is left,
    anything else right), a ``mask`` suffix is a brain mask, and a ``bold``
    suffix with ``desc-preproc`` is the volumetric data; other files are ignored.

    Parameters
    ----------
    pattern
        Glob pattern ending with ``*`` (e.g.
        ``"derivatives/deepprep/.../sub-01_ses-01_task-rest_run-1_*"``).
        Only the file-name part is globbed; the parent directory is used as is.
    preproc
        Preprocessing pipeline name (``"deepprep"``, ``"fmriprep"``).
    **common
        Fields shared across all output dicts (``start``, ``frequency``, …).

    Returns
    -------
    list[dict]
        One dict per discovered space, ready for ``pd.DataFrame``. Spaces
        with neither a left-hemisphere file nor a data file are skipped.
    """
    from bids.layout import parse_file_entities  # type: ignore[import-untyped]

    fp_pattern = Path(pattern)
    fps: defaultdict[str, dict[str, str]] = defaultdict(dict)
    for ext in (".nii.gz", ".gii"):
        # glob order depends on the filesystem: sort so that the output order
        # and the file kept when several match the same part are reproducible
        for fp in sorted(fp_pattern.parent.glob(fp_pattern.name + ext)):
            ents = parse_file_entities(str(fp))
            space = ents.get("space")
            if space is None:
                continue
            if "hemi" in ents:
                part = "left" if ents["hemi"] == "L" else "right"
            elif ents.get("suffix") == "mask":
                part = "mask"
            elif ents.get("suffix") == "bold" and ents.get("desc") == "preproc":
                part = "data"
            else:
                continue
            fps[space][part] = str(fp)
    typed: list[dict[str, tp.Any]] = []
    for space, parts in fps.items():
        kw: dict[str, tp.Any] = {
            **common,
            "space": space,
            "preproc": preproc,
            "type": "Fmri",
        }
        if "left" in parts:
            # Only the left-hemi path is stored; Fmri._read() derives
            # the right hemisphere by replacing hemi-L with hemi-R.
            kw.update(filepath=parts["left"])
        elif "data" in parts:
            kw.update(filepath=parts["data"], mask_filepath=parts.get("mask"))
        else:
            continue
        typed.append(kw)
    return typed


def query_with_index(df: pd.DataFrame, query: str) -> pd.DataFrame:
    """Execute a pandas query with auto-generated index columns.

    Recognises ``<col>_index`` and ``<col1>_<col2>_index`` tokens in the
    query string and materialises them as temporary columns before executing
    the query. Temporary columns are dropped from the result.

    Example::

        # Filter by subject name
        query_with_index(df, 'subject == "Subject1"')

        # Keep only the first 2 subjects
        query_with_index(df, "subject_index < 2")

        # Keep only the first timeline per subject
        query_with_index(df, "subject_timeline_index < 1")
    """
    out = df.copy()
    temp_cols: list[str] = []
    columns = set(df.columns)

    for token in re.findall(r"\b(\w+_index)\b", query):
        if token in temp_cols:
            continue
        if token in columns:
            raise ValueError(
                f"Ambiguous index token {token!r}: a column with that name "
                "already exists in the dataframe."
            )
        prefix = token.rsplit("_index", 1)[0]
        if prefix in columns:
            out.loc[:, token] = out.groupby(prefix, sort=False).ngroup()
            temp_cols.append(token)
        else:
            matched = False
            for col in sorted(columns, key=len, reverse=True):
                if not prefix.startswith(col + "_"):
                    continue
                remainder = prefix[len(col) + 1 :]
                if remainder in columns:
                    out.loc[:, token] = out.groupby(col, sort=False)[remainder].transform(
                        lambda x: pd.factorize(x)[0]
                    )
                    temp_cols.append(token)
                    matched = True
                    break
            if not matched:
                raise ValueError(
                    f"Cannot resolve index token {token!r}: "
                    "no matching columns found in the dataframe"
                )

    result = out.query(query, engine="python")
    if temp_cols:
        result = result.drop(columns=temp_cols)
    return result


def standardize_events(
    events: pd.DataFrame,
    auto_fill: bool = True,
) -> pd.DataFrame:
    """Normalize an events DataFrame into canonical form.

    Sorts by (timeline, start ASC, duration DESC) preserving first-seen
    timeline order, fills BIDS entity columns, reorders columns, and adds
    a computed ``stop`` column. Always returns a new DataFrame.

    Parameters
    ----------
    events : pd.DataFrame
        DataFrame containing events. Must have at least a ``type`` column
        with string values; after the optional ``auto_fill`` step it must
        also have ``start``, ``duration`` and ``timeline`` columns without NaN.
    auto_fill : bool
        If True (default), round-trip each row through its pydantic Event
        model (``from_dict`` → ``to_dict``). This fills missing fields
        with defaults, validates types, and coerces values. This is
        **significantly slower** than the rest of the normalization
        (~10 s on 100k rows vs ~60 ms) and is only needed the first time
        raw events are ingested. Pass ``False`` when the data has already
        been through ``auto_fill`` (e.g. post-concat re-normalization,
        transform wrapper, cache load).

    Returns
    -------
    pd.DataFrame
        Normalized DataFrame. The ``stop`` column (``start + duration``)
        signals that normalization has been applied.

    Raises
    ------
    ValueError
        If ``events`` is empty, lacks a string ``type`` column, lacks one of
        the ``start``/``duration``/``timeline`` columns, or has NaN in them.

    Warns
    -----
    UserWarning
        If an ``index``/``Index`` column is present (it is dropped), or, with
        ``auto_fill``, if some events have a duration <= 0.
    """
    if events.empty:
        raise ValueError("Cannot normalize an empty events DataFrame.")
    for name in ["index", "Index"]:
        if name in events.columns:
            msg = f"The events dataframe contains an `{name}` column. This is "
            msg += "dangerous, please add drop=True in calls to df.reset_index(). "
            msg += "Dropping it automatically."
            warnings.warn(msg)
            events = events.drop(columns=[name])
    msg = 'events DataFrame must have a "type" column with strings'
    if "type" not in events.keys():
        raise ValueError(msg)
    types = events["type"].unique()
    if not all(isinstance(typ, str) for typ in types):
        raise ValueError(msg)
    if auto_fill:
        validated = events.apply(_coerce_event, axis=1)
        df = pd.DataFrame(validated.tolist(), index=events.index)
    else:
        df = events
    for col in ("start", "duration"):
        if col not in df.columns:
            raise ValueError(f"Events DataFrame is missing required column '{col}'.")
        if df[col].isna().any():
            n = df[col].isna().sum()
            raise ValueError(
                f"Events DataFrame has {n} row(s) with NaN {col}. "
                f"Use auto_fill=True if events were built without {col}."
            )
    # the sort and column reordering below both need a timeline column
    if "timeline" not in df.columns:
        raise ValueError("Events DataFrame is missing required column 'timeline'.")
    if df["timeline"].isna().any():
        n = df["timeline"].isna().sum()
        raise ValueError(
            f"Events DataFrame contains {n} row(s) with NaN timeline. "
            "All events must have a valid 'timeline' value."
        )
    if auto_fill:
        # checked only once "duration" is known to exist (see validation above)
        null = df.loc[df["duration"] <= 0, :]
        if not null.empty:
            null_types = null["type"].unique()
            msg = f"Found {len(null)} event(s) with null duration (types: {null_types})"
            warnings.warn(msg)
    # Sort by (timeline, start ASC, duration DESC) preserving first-seen timeline order
    tl_order = pd.Categorical(
        df["timeline"], categories=df["timeline"].unique(), ordered=True
    )
    df = df.assign(_tl_order=tl_order.codes)
    df = df.sort_values(
        ["_tl_order", "start", "duration"],
        ascending=[True, True, False],
        ignore_index=True,
    )
    df = df.drop(columns=["_tl_order"])
    for entity in BIDS_ENTITIES:
        if entity not in df.columns:
            df[entity] = BIDS_ENTITY_DEFAULT
        df[entity] = df[entity].fillna(BIDS_ENTITY_DEFAULT).astype(str)
    important = ["type", "start", "duration", "timeline"] + list(BIDS_ENTITIES)
    columns = important + [c for c in df.columns if c not in important]
    df = df.loc[:, columns]
    df = df.assign(stop=lambda x: x.start + x.duration)
    return df


def _coerce_event(event: pd.Series) -> dict[str, tp.Any]:
    """Coerce a single event row through its pydantic Event model.

    Instantiates the Event subclass for this row's ``type``, which validates
    fields and fills defaults, then converts back to a dict (values from the
    model override the row's original values; extra row fields are kept).

    Rows whose ``type`` is not a registered class are returned unchanged (as a
    dict) after a one-time warning.

    Raises
    ------
    ValueError
        If ``type`` is the lower-cased form of a registered class name, which
        is most likely a capitalization mistake.
    """
    from . import etypes

    classes = etypes.Event._CLASSES
    event_type = event["type"]
    if event_type in classes:
        event_obj = classes[event_type].from_dict(event).to_dict()
        return {**event, **event_obj}
    # Built only for unregistered types: this function runs once per row.
    # Maps the lower-cased class name to the registered (canonical) name so
    # that the hint below names an existing class.
    by_lower = {name.lower(): name for name in classes}
    if event_type in by_lower:
        raise ValueError(
            f"Unknown event type {event_type!r} "
            f"(did you mean {by_lower[event_type]!r}?)"
        )
    utils.warn_once(
        f'Unexpected type "{event["type"]}". Support for new event '
        "types can be added by creating new `Event` classes in "
        "`neuralset.events`."
    )
    return {**event}


@DumpContext.register
class ValidatedParquet(ParquetPandasDataFrame):
    """Parquet cache type that coerces events before dumping and
    normalizes on load.

    Dump: coerces each event through its pydantic model (safety net),
    strips the derived ``stop`` column before writing, and stringifies the
    non-null values of object columns holding several Python types.
    Load: normalizes (re-sorts, fills BIDS columns, recomputes ``stop``)
    without re-coercing events.
    """

    @classmethod
    def __dump_info__(cls, ctx: DumpContext, value: tp.Any) -> dict[str, tp.Any]:
        value = standardize_events(value)
        value = value.drop(columns=["stop"], errors="ignore")
        # Cast object-dtype columns with mixed types (e.g. int + str) to str
        # so pyarrow can serialize them without ArrowInvalid errors.
        for col in value.columns:
            series = value[col]
            if series.dtype != object:
                continue
            if len({type(v) for v in series.dropna()}) > 1:
                # Stringify only non-null entries: a plain ``astype(str)``
                # would turn missing values into the literal strings
                # "None"/"nan", which would then be reloaded as real data.
                value[col] = series.astype(str).where(series.notna(), series)
        return super().__dump_info__(ctx, value)

    @classmethod
    def __load_from_info__(cls, ctx: DumpContext, filename: str) -> tp.Any:
        df = pd.read_parquet(ctx.folder / filename)
        # already coerced at dump time, so skip the slow auto_fill step
        return standardize_events(df, auto_fill=False)


class SpecialLoader:
    """Loader for special methods that need to be serialized and called later.

    Parameters
    ----------
    method: method
        method to use for loading
    timeline: dict[str, Any]
        parameters defining the timeline (subject, task etc)
    **kwargs: Any
        any additional (json-able) parameters used to call the method

    Example
    -------
    In ``_load_timeline_events``::

        loader = SpecialLoader(method=self._my_method, timeline=timeline, additional_param=...)

    In the method::

        def _my_method(self, timeline: dict[str, tp.Any], additional_param):
            ...
    """

    def __init__(
        self,
        method: tp.Callable[..., tp.Any],
        timeline: dict[str, tp.Any],
        **kwargs: tp.Any,
    ) -> None:
        self.method = method
        self.timeline = timeline
        self.kwargs = kwargs

    @classmethod
    def from_json(cls, string: str) -> "SpecialLoader":
        """Rebuild a loader from a string produced by :meth:`to_json`.

        The study class named in the payload is instantiated again (with its
        registered path and the serialized ``cls_kwargs``) and the method is
        looked up on that new instance.

        Raises
        ------
        RuntimeError
            If the study class is not registered in this process.
        """
        from .study import STUDIES, STUDY_PATHS  # deferred: study imports utils

        payload = ujson.loads(string)
        study_name = payload["cls"]
        is_known = study_name in STUDIES and study_name in STUDY_PATHS
        if not is_known:
            raise RuntimeError(
                f"Study class {study_name!r} is not available in this process "
                f"(registered: {list(STUDIES)}, with path: {list(STUDY_PATHS)}).\n"
                "A Study instance must be constructed before SpecialLoader.from_json "
                "can be used — in a child job / subprocess, make sure the study "
                "module is imported and the Study is instantiated."
            )
        study_cls = STUDIES[study_name]
        study = study_cls(
            path=STUDY_PATHS[study_name], **payload.get("cls_kwargs", {})
        )
        return cls(
            method=getattr(study, payload["method"]),
            timeline=payload["timeline"],
            **payload.get("kwargs", {}),
        )

    def load(self) -> tp.Any:
        """Call the stored method with the timeline and extra keyword arguments."""
        return self.method(timeline=self.timeline, **self.kwargs)  # type: ignore

    def to_json(self) -> str:
        """Serialize the loader as a JSON string with sorted keys.

        ``method`` must be bound to an instance that provides
        ``_cls_kwargs()``; its class name, the method name, the timeline and
        any non-empty kwargs / class kwargs are stored.
        """
        owner = self.method.__self__  # type: ignore
        payload = {
            "cls": owner.__class__.__name__,
            "timeline": self.timeline,
            "method": self.method.__name__,
        }
        if self.kwargs:
            payload["kwargs"] = self.kwargs
        owner_kwargs = owner._cls_kwargs()
        if owner_kwargs:
            payload["cls_kwargs"] = owner_kwargs
        return ujson.dumps(payload, sort_keys=True)

    def specs(self) -> dict[str, list[str]]:
        """Extract configurable parameters from the method's Literal type hints.

        Every parameter other than ``self``/``timeline`` whose resolved type
        hint has only string arguments (as ``Literal["a", "b"]`` does) is
        reported with those strings, e.g.
        ``{"registration": ["msmall", "none", "mni"], ...}``.
        """
        hints = tp.get_type_hints(self.method)
        parameters = inspect.signature(self.method).parameters
        found: dict[str, list[str]] = {}
        for pname in parameters:
            if pname in ("self", "timeline"):
                continue
            hint = hints.get(pname)
            if hint is None:
                continue
            choices = tp.get_args(hint)
            if choices and all(isinstance(choice, str) for choice in choices):
                found[pname] = list(choices)
        return found

    def specs_json(self) -> str:
        """JSON string of :meth:`specs`, suitable for a DataFrame column."""
        return ujson.dumps(self.specs(), sort_keys=True)