Source code for neuralbench.transforms

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import numpy as np
import pandas as pd
import pydantic
from sklearn.model_selection import train_test_split

import neuralset as ns
from neuralset.events import transforms as _transf


class TextPreprocessor(_transf.EventsTransform):
    """Clean and filter text-related events.

    The following operations are applied to the events:

    - Keep only events with duration >= 0
    - Keep only neuro events (``neuro_event_type``), Audio events, or events
      whose ``text`` is a string (Word events, but also any other event type
      carrying string text, e.g. Sentence events)
    - Clean the ``text`` column of Word events by removing special characters
      and lowercasing
    - Drop events whose text is empty or whitespace-only
    - For Nieuwland2018, group similar sentences together to avoid leakage
      (each sentence has two very similar versions)
    """

    neuro_event_type: str = "Eeg"

    @staticmethod
    def clean_text(x: str) -> str:
        """Remove special characters and lowercase the text.

        Alphanumeric characters, spaces, hyphens and apostrophes are kept;
        everything else is dropped.
        """
        return "".join(e for e in x if e.isalnum() or e in [" ", "-", "'"]).lower()

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        # Any event with string text counts as a text event (not only Word
        # events); neuro and Audio events are kept regardless of their text.
        has_text = events.text.apply(lambda x: isinstance(x, str))
        keep_event = (events.duration >= 0) & (
            events.type.isin([self.neuro_event_type, "Audio"]) | has_text
        )
        events = events.loc[keep_event].copy()

        # Clean the 'text' column for Word events only
        word_mask = events.type == "Word"
        events.loc[word_mask, "text"] = events.loc[word_mask, "text"].apply(
            self.clean_text
        )

        # Drop empty or whitespace-only text entries. Cleaning can leave strings
        # made only of spaces (e.g. "! ! !" -> "  "), which a fixed list such as
        # ["", " "] would miss. Non-string text (e.g. NaN on neuro events) is kept.
        is_blank = events.text.apply(lambda x: isinstance(x, str) and not x.strip())
        events = events.loc[~is_blank].copy()

        # Guard against empty input (iloc[0] would raise) and a missing column
        if (
            not events.empty
            and "study" in events.columns
            and events.study.iloc[0] == "Nieuwland2018"
        ):
            events["sequence_id"] = events.groupby("sentence").ngroup() // 2
        return events
# Re-export of neuralset's ``SklearnSplit`` so that it is also reachable from
# this module's namespace, next to the splitters defined here.
SklearnSplit = _transf.SklearnSplit
class SimilaritySplit(_transf.EventsTransform):
    """Perform train/val/test split based on similarity of sentence events.

    Depending on the type of stimulus event that is expected, the behavior is as
    follows:

    - For Audio events, propagate sentence mapping to Word events, then chunk Audio
      events based on Word events.
    - For Keystroke events, propagate sentence mapping to Keystroke events.
    - For Sentence or Word events, directly apply the similarity-based split.

    Parameters
    ----------
    stim_event_type : {"Sentence", "Word", "Audio", "Keystroke"}
        Type of stimulus event; selects how the split is propagated (see above).
    valid_split_ratio : float
        Fraction of the data assigned to the validation split.
    test_split_ratio : float
        Fraction of the data assigned to the test split. The sum
        ``valid_split_ratio + test_split_ratio`` must not exceed 1; the remainder
        is the training fraction.
    valid_random_state, test_random_state : int
        Seeds forwarded to ``SklearnSplit``.
    threshold : float
        Threshold forwarded to ``SimilaritySplitter``.
    use_sklearn_split : bool
        If True, use sklearn's `train_test_split` after computing clusters, rather
        than using `SimilaritySplitter`'s deterministic cluster assignment.

    NOTE: `valid_random_state` and `test_random_state` are ignored unless
    `use_sklearn_split` is True.

    Raises
    ------
    ValueError
        If the validation and test ratios sum to more than 1, or if Keystroke
        events are requested but there is no ``sentence`` column.
    """

    stim_event_type: tp.Literal["Sentence", "Word", "Audio", "Keystroke"]
    valid_split_ratio: float = pydantic.Field(
        0.2, strict=True, ge=0.0, le=1.0, allow_inf_nan=False
    )
    test_split_ratio: float = pydantic.Field(
        0.2, strict=True, ge=0.0, le=1.0, allow_inf_nan=False
    )
    valid_random_state: int = 33
    test_random_state: int = 33
    threshold: float = 0.2
    use_sklearn_split: bool = False

    def _split(
        self, events: pd.DataFrame, splitter: _transf.SimilaritySplitter
    ) -> pd.DataFrame:
        """Return ``events`` with a ``split`` column assigned via ``splitter``.

        In the sklearn path, ``events`` is updated in place and returned.
        """
        if self.use_sklearn_split:
            # Compute clusters, then let SklearnSplit assign splits by cluster_id
            splitted_events = splitter.compute_clusters(events)
            sklearn_splitter = _transf.SklearnSplit(
                split_by="cluster_id",
                valid_split_ratio=self.valid_split_ratio,
                test_split_ratio=self.test_split_ratio,
                valid_random_state=self.valid_random_state,
                test_random_state=self.test_random_state,
            )
            splitted_events = sklearn_splitter(splitted_events)
            events.loc[splitted_events.index, "split"] = splitted_events["split"]
        else:
            # Use SimilaritySplitter's deterministic cluster assignment
            events = splitter(events)
        return events

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        held_out = self.valid_split_ratio + self.test_split_ratio
        # Small tolerance so that sums hitting 1.0 up to float rounding pass
        if held_out > 1.0 + 1e-9:
            raise ValueError(
                "valid_split_ratio + test_split_ratio must not exceed 1.0, got "
                f"{self.valid_split_ratio} + {self.test_split_ratio} = {held_out}."
            )
        ratios = {
            "train": max(0.0, 1.0 - held_out),
            "val": self.valid_split_ratio,
            "test": self.test_split_ratio,
        }
        # The branches below assign to ``split`` in place; work on a copy so the
        # caller's DataFrame is not mutated.
        events = events.copy()
        # Initialize splitter over Sentence events
        splitter = _transf.SimilaritySplitter(
            extractor=ns.extractors.text.TfidfEmbedding(),  # TODO: parametrize
            ratios=ratios,
            threshold=self.threshold,
        )

        if self.stim_event_type == "Audio":  # Designed for Accou2023
            events = self._split(events, splitter)
            # Propagate each Sentence's split to the Word events of that sentence
            sentence_split_mapping = (
                events.loc[events.type == "Sentence", ["text", "split"]]
                .set_index("text")["split"]
                .to_dict()
            )
            events.loc[events.type == "Word", "split"] = events.loc[
                events.type == "Word", "sentence"
            ].map(sentence_split_mapping)
            # Chunk Audio events based on Word segments
            events = _transf.chunk_events(
                events,
                event_type_to_chunk="Audio",
                event_type_to_use="Word",
                min_duration=3.0,
                max_duration=120.0,
            )

        elif self.stim_event_type == "Keystroke":  # Designed for Levy2025Brain
            # For the Typing dataset, we need to transfer the clusters from the
            # sentences to the keystroke events
            if "sentence" not in events.columns:
                raise ValueError(
                    "Expected 'sentence' column in Keystroke events for similarity-based splitting."
                )
            unique_sentences = (
                events.loc[events.type == "Keystroke", "sentence"].dropna().unique()
            )
            # Fake dataset creation to be able to populate the keystroke events later
            sentence_df = pd.DataFrame(
                {
                    "text": unique_sentences,
                    "type": "Sentence",
                    "start": 0.0,
                    "duration": 0.0,
                    "timeline": "dummy_tl",
                }
            )
            # Run similarity splitter on these pseudo-Sentence rows
            sentence_df = self._split(sentence_df, splitter)
            # Mapping sentence -> split
            sentence_split_map = sentence_df.set_index("text")["split"].to_dict()
            events.loc[events.type == "Keystroke", "split"] = events.loc[
                events.type == "Keystroke", "sentence"
            ].map(sentence_split_map)

        elif self.stim_event_type in ["Sentence", "Word"]:
            events = self._split(events, splitter)

        return events
class PredefinedSplit(_transf.EventsTransform):
    """Assign train/test labels based on a predefined split, and optionally split
    train into validation as well.

    Parameters
    ----------
    event_type : str | None
        If provided, only split events of this type; the other events are
        concatenated back (before the split events) with a fresh index.
    test_split_query : str | None
        If provided, query used to create a test split; all other selected events
        are labelled "train". If None, the ``col_name`` column must already exist
        and contain a "test" split.
    col_name : str
        Column name to use for the created split.
    valid_split_by : str | None
        Column whose unique (non-NaN) values are randomly assigned to the train or
        validation split. The special value ``"_index"`` creates a per-row counter
        (if the column is missing) so individual events are assigned. If None, no
        validation split is created.
    valid_split_ratio : float
        Ratio of the **training set** to use for validation. CAUTION: This is
        unlike the other splitter (e.g. SklearnSplit) where the validation split is
        a ratio of the entire dataset.
    valid_random_state : int
        Seed for the train/validation draw.
    test_random_state : int | None
        Unused - for compatibility with SklearnSplit.

    Raises
    ------
    ValueError
        If no "test" split is available, or if the validation draw is empty.
    """

    event_type: str | None = None
    test_split_query: str | None = None
    col_name: str = "split"
    valid_split_by: str | None = "timeline"
    valid_split_ratio: float = 0.2
    valid_random_state: int = 33
    test_random_state: int | None = None  # Unused - for compatibility with SklearnSplit

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        events = events.copy()
        if self.event_type is not None:
            unused_events = events[events.type != self.event_type]
            # .copy() so the assignments below act on an independent frame
            events = events[events.type == self.event_type].copy()

        if self.test_split_query is not None:
            test_inds = events.query(self.test_split_query).index
            events[self.col_name] = "train"
            events.loc[test_inds, self.col_name] = "test"
        if self.col_name not in events or "test" not in events[self.col_name].values:
            raise ValueError("Predefined train/test split required.")

        # Optionally assign random samples of train to validation split
        if self.valid_split_by is not None:
            train_events = events[events[self.col_name] == "train"].copy()
            if self.valid_split_by == "_index" and "_index" not in train_events.columns:
                train_events["_index"] = range(len(train_events))
                events.loc[train_events.index, "_index"] = train_events["_index"]
            train_groups = np.asarray(train_events[self.valid_split_by].dropna().unique())
            train, valid = train_test_split(
                train_groups,
                test_size=self.valid_split_ratio,
                random_state=self.valid_random_state,
            )
            if len(valid) == 0:
                raise ValueError(
                    "Empty validation set, try increasing `valid_split_ratio`."
                )
            split_mapping = {
                **{k: "train" for k in train},
                **{k: "val" for k in valid},
            }
            # Rows whose group is NaN were excluded from the draw: keep them in
            # "train" rather than overwriting their label with NaN.
            new_split = train_events[self.valid_split_by].map(split_mapping)
            events.loc[train_events.index, self.col_name] = new_split.fillna("train")

        if self.event_type is not None:
            events = pd.concat([unused_events, events], ignore_index=True, axis=0)
        return events
class CropSleepRecordings(_transf.EventsTransform):
    """Keep up to max_wake_duration_min mins of wake (W) time before and after the
    first and last sleep events.

    Only the wake event immediately preceding the first non-wake stage and the
    wake event immediately following the last non-wake stage of each timeline are
    cropped. Events are assumed to be ordered by time within each timeline.
    """

    max_wake_duration_min: float = 30.0

    def crop_first_last_wake(self, evs: pd.DataFrame) -> pd.DataFrame:
        """Crop the wake events bordering the sleep period of one timeline.

        Parameters
        ----------
        evs : pd.DataFrame
            SleepStage events of a single timeline, with ``stage``, ``start`` and
            ``duration`` columns (presumably sorted by ``start`` - confirm).

        Returns
        -------
        pd.DataFrame
            A cropped copy of ``evs``; the input frame is left untouched. The
            leading wake event keeps only its last ``max_wake_duration_min``
            minutes, the trailing one only its first minutes (``stop`` updated).
        """
        # groupby.apply does not support mutating the group it passes in
        evs = evs.copy()
        max_wake_s = self.max_wake_duration_min * 60.0
        non_wake_pos = np.flatnonzero((evs.stage != "W").to_numpy())
        if len(non_wake_pos) == 0:
            return evs
        first_pos, last_pos = non_wake_pos[0], non_wake_pos[-1]

        if first_pos > 0:
            # Wake event right before sleep onset: keep only its tail
            ind = evs.index[first_pos - 1]
            start = float(evs.at[ind, "start"])
            duration = float(evs.at[ind, "duration"])
            new_duration = min(max_wake_s, duration)
            evs.loc[ind, "duration"] = new_duration
            evs.loc[ind, "start"] = start + duration - new_duration

        if last_pos < len(evs) - 1:
            # Wake event right after the last sleep stage: keep only its head
            ind = evs.index[last_pos + 1]
            start = float(evs.at[ind, "start"])
            duration = float(evs.at[ind, "duration"])
            new_duration = min(max_wake_s, duration)
            evs.loc[ind, "duration"] = new_duration
            evs.loc[ind, "stop"] = start + new_duration
        return evs

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        is_sleep = events.type == "SleepStage"
        # Nothing to crop; also avoids droplevel() failing on an empty apply
        if not is_sleep.any():
            return events
        sleep_events = (
            events[is_sleep]
            .groupby("timeline")
            .apply(self.crop_first_last_wake, include_groups=False)  # type: ignore
            .droplevel(1)  # drop original row index, keep timeline
            .reset_index()
        )
        events = pd.concat(
            [
                events[~is_sleep],
                sleep_events,
            ],
            ignore_index=True,
            axis=0,
        ).sort_values(by=["timeline", "start"])
        return events
class CropTimelines(_transf.EventsTransform):
    """Crop neuro timelines.

    Parameters
    ----------
    event_type : str | None
        If provided, only crop events of this type; otherwise all events are
        cropped.
    start_offset_s : float | None
        If provided, crop this offset from the start of the event (its start is
        moved later and its duration shortened by the same amount).
    max_duration_s : float | None
        If provided, cap the event at this duration.

    Raises
    ------
    ValueError
        If cropping leaves any cropped event with a non-positive duration.
    """

    event_type: str | None = None
    start_offset_s: float | None = None
    max_duration_s: float | None = None

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        events = events.copy()
        if self.event_type is not None:
            mask = events.type == self.event_type
        else:
            mask = pd.Series(True, index=events.index)

        # Assign through the mask directly (rather than DataFrame.update on a
        # filtered slice) to avoid chained assignment and dtype upcasting.
        if self.start_offset_s is not None:
            events.loc[mask, "start"] += self.start_offset_s
            events.loc[mask, "duration"] -= self.start_offset_s
        if self.max_duration_s is not None:
            events.loc[mask, "duration"] = np.minimum(
                events.loc[mask, "duration"].to_numpy(), self.max_duration_s
            )
        cropped = self.start_offset_s is not None or self.max_duration_s is not None
        if cropped and "stop" in events.columns:
            # Keep stop consistent with the cropped start/duration
            events.loc[mask, "stop"] = (
                events.loc[mask, "start"] + events.loc[mask, "duration"]
            )

        if (events.loc[mask, "duration"] <= 0).any():
            raise ValueError(
                "Cropping timelines resulted in non-positive duration for some events. "
                "Check that `start_offset_s` is not larger than the shortest event duration."
            )
        return events
class AddDefaultEvents(_transf.EventsTransform):
    """Add default events to a timeline to fill out its duration.

    For each timeline, every interval not covered by a ``target_event_type``
    event (between the earliest start and latest stop of the timeline) becomes a
    new ``default_event_type`` event. E.g., useful for epilepsy recordings where
    seizures are sparse.

    Parameters
    ----------
    target_event_type : str
        Event type to fill out.
    default_event_type : str
        Event type to use for the added events.
    default_event_fields : dict[str, Any]
        Additional fields to add to the created events, as a dictionary of
        (field_name, value).
    """

    target_event_type: str
    default_event_type: str
    default_event_fields: dict[str, tp.Any] = {}

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        all_events = []
        for tl, group in events.groupby("timeline"):
            tl_start, tl_stop = group.start.min(), group.stop.max()
            # Sort so gaps are found in chronological order
            target_intervals = group.loc[
                group.type == self.target_event_type, ["start", "stop"]
            ].sort_values("start")
            new_intervals = []
            cursor = tl_start  # end of the time already covered
            for itv_start, itv_stop in target_intervals.itertuples(
                index=False, name=None
            ):
                if itv_start > cursor:
                    new_intervals.append((cursor, itv_start))
                # max(): an interval nested in a previous one must not move the
                # cursor backwards (which would create a spurious gap)
                cursor = max(cursor, itv_stop)
            if cursor < tl_stop:
                new_intervals.append((cursor, tl_stop))

            new_events = pd.DataFrame(new_intervals, columns=["start", "stop"])
            new_events["type"] = self.default_event_type
            new_events["timeline"] = tl
            new_events["duration"] = new_events.stop - new_events.start
            for field, value in self.default_event_fields.items():
                new_events[field] = value
            all_events.extend([group, new_events])

        # pd.concat raises on an empty list (no timelines)
        if not all_events:
            return events
        return pd.concat(all_events, ignore_index=True, axis=0)
class OffsetEvents(_transf.EventsTransform):
    """Offset selected events by specified amounts.

    New boundaries are clipped to the span of the event's timeline (earliest
    start / latest stop), and ``duration`` is recomputed as ``stop - start``.

    Parameters
    ----------
    query: str
        Query to select events to offset.
    start_offset: float | None = None
        New start = original start + ``start_offset``.
    end_offset: float | None = None
        New stop = original stop + ``end_offset``.
    start_offset_from_end: float | None = None
        New start = original stop + ``start_offset_from_end``.
    end_offset_from_start: float | None = None
        New stop = original start + ``end_offset_from_start``.

    Raises
    ------
    ValueError
        At construction, if conflicting or no offsets are specified; at run
        time, if an offset event ends up with a non-positive duration.
    """

    query: str
    start_offset: float | None = None
    end_offset: float | None = None
    start_offset_from_end: float | None = None
    end_offset_from_start: float | None = None

    def model_post_init(self, __context: tp.Any) -> None:
        super().model_post_init(__context)
        if self.start_offset is not None and self.start_offset_from_end is not None:
            raise ValueError("Cannot specify both start_offset and start_offset_from_end")
        if self.end_offset is not None and self.end_offset_from_start is not None:
            raise ValueError("Cannot specify both end_offset and end_offset_from_start")
        if (
            self.start_offset is None
            and self.start_offset_from_end is None
            and self.end_offset is None
            and self.end_offset_from_start is None
        ):
            msg = "At least one of start_offset, start_offset_from_end, end_offset, or "
            msg += "end_offset_from_start must be specified"
            raise ValueError(msg)

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        # Write into a copy so the caller's DataFrame is not mutated; groups are
        # read from the untouched input so all offsets use original boundaries.
        out = events.copy()
        offset_index: list[tp.Any] = []
        for _, group in events.groupby("timeline"):
            min_tl_start, max_tl_stop = group.start.min(), group.stop.max()
            sel_events = group.query(self.query)
            if sel_events.empty:
                continue
            idx = sel_events.index

            if self.start_offset is not None:
                new_start = sel_events.start + self.start_offset
            elif self.start_offset_from_end is not None:
                new_start = sel_events.stop + self.start_offset_from_end
            else:
                new_start = sel_events.start
            out.loc[idx, "start"] = np.maximum(min_tl_start, new_start)

            if self.end_offset is not None:
                new_stop = sel_events.stop + self.end_offset
            elif self.end_offset_from_start is not None:
                new_stop = sel_events.start + self.end_offset_from_start
            else:
                new_stop = sel_events.stop
            out.loc[idx, "stop"] = np.minimum(max_tl_stop, new_stop)

            out.loc[idx, "duration"] = out.loc[idx, "stop"] - out.loc[idx, "start"]
            offset_index.extend(idx)

        # Only check the events we modified: other events may legitimately have
        # zero duration and must not trigger this error.
        if (out.loc[offset_index, "duration"] <= 0).any():
            msg = (
                "Offsetting events resulted in negative or zero duration for some events."
            )
            raise ValueError(msg)
        return out
class ShuffleTrainingLabels(_transf.EventsTransform):
    """Permute a label column across training rows only (diagnostic ablation).

    Rows with ``split == "train"`` (optionally restricted to the given
    ``event_types``) have their ``event_field`` values shuffled among
    themselves with a seeded generator. The multiset of training labels is
    therefore unchanged, but each row's pairing with its label is broken.
    Validation and test rows are never touched.

    Intended use: a check for pre-training leakage in foundation models. A model
    fine-tuned on permuted training labels that still beats chance on the intact
    test labels suggests the encoder memorised dataset-level structure during
    pre-training; genuine feature learning should drop to chance here.

    This transform expects the ``split`` column to be present, i.e. it must run
    after a split transform.

    Parameters
    ----------
    event_field : str
        Column whose values are permuted (e.g. ``"code"`` for LabelEncoder).
    event_types : str or list of str or None
        Restrict the permutation to events whose ``type`` is listed here. With
        ``None``, all training rows are permuted together regardless of type.
    random_state : int
        Seed of the permutation, for reproducibility.
    """

    event_field: str
    event_types: str | list[str] | None = None
    random_state: int = 33

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        available = events.columns
        if "split" not in available:
            raise ValueError(
                "ShuffleTrainingLabels requires a 'split' column. Run a split "
                "transform (e.g. SklearnSplit, PredefinedSplit) first."
            )
        if self.event_field not in available:
            raise ValueError(
                f"event_field {self.event_field!r} not found in events columns."
            )

        shuffled = events.copy()
        selected = shuffled["split"] == "train"
        if self.event_types is not None:
            if isinstance(self.event_types, str):
                allowed = [self.event_types]
            else:
                allowed = list(self.event_types)
            selected = selected & shuffled["type"].isin(allowed)

        # Zero or one row: nothing to permute
        if int(selected.sum()) < 2:
            return shuffled

        labels = shuffled.loc[selected, self.event_field].to_numpy().copy()
        generator = np.random.default_rng(self.random_state)
        generator.shuffle(labels)
        shuffled.loc[selected, self.event_field] = labels
        return shuffled