# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Segments: lightweight time-window views into a shared event store.
Create segments with :func:`list_segments`, then pass them to
:class:`~neuralset.dataloader.SegmentDataset` to feed extractors via a
PyTorch ``DataLoader``.
"""
from __future__ import annotations
import collections
import dataclasses
import itertools
import logging
import math
import typing as tp
import uuid
import weakref
import numpy as np
import pandas as pd
from neuralset.events.etypes import Event
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# _EventStore
# ---------------------------------------------------------------------------
@dataclasses.dataclass
class _EventBucket:
"""Sorted arrays for one event-type group within a timeline."""
starts: np.ndarray # sorted event start times
stops: np.ndarray # start + duration (parallel to starts)
events: list[Event] # parallel to starts/stops
max_dur: float # max(duration) in this bucket
@classmethod
def from_events(
cls, events: list[Event], min_bucket_size: int = 50
) -> list[_EventBucket]:
"""Build sorted buckets from events.
Each bucket tracks its own ``max_dur`` so searchsorted windows
stay tight (e.g. 0.3 s for Words vs 500 s for Stimulus). Types
with fewer than ``min_bucket_size`` events are merged into one
bucket since their scan cost is negligible regardless of window.
"""
by_type: dict[str, list[Event]] = collections.defaultdict(list)
for e in events:
by_type[e.type].append(e)
groups: dict[str, list[Event]] = {}
merged: list[Event] = []
for typ, typed in by_type.items():
if len(typed) >= min_bucket_size:
groups[typ] = typed
else:
merged.extend(typed)
if merged:
groups[""] = merged
buckets: list[_EventBucket] = []
for typed in groups.values():
typed.sort(key=lambda e: e.start)
starts = np.array([e.start for e in typed])
durations = np.array([e.duration for e in typed])
buckets.append(
cls(
starts=starts,
stops=starts + durations,
events=typed,
max_dur=float(durations.max()) if len(durations) else 0.0,
)
)
return buckets
def overlapping(self, start: float, stop: float) -> list[Event]:
"""Return events overlapping the ``[start, stop)`` window."""
events = self.events
# O(N) linear scan — faster for small buckets due to no Python overhead
if len(self.starts) <= 150:
mask = (self.starts < stop) & (self.stops > start)
return list(itertools.compress(events, mask))
# O(log N + k) searchsorted — widen lo by max_dur so we don't
# miss long events that start early but extend into the window
lo = int(np.searchsorted(self.starts, start - self.max_dur))
hi = int(np.searchsorted(self.starts, stop))
if lo >= hi:
return []
# Filter candidates whose actual end time overlaps the query start
mask = self.stops[lo:hi] > start
return list(itertools.compress(events[lo:hi], mask))
class _EventStore:
"""Shared, read-only container for events and a fast overlap index.
Built once from a DataFrame via :meth:`from_dataframe`. Segments hold
a reference (strong in-process, UUID across pickle) and resolve
overlapping events through :meth:`overlapping`.
"""
# process-local registry; segments look up stores by UUID after unpickling
_REGISTRY: tp.ClassVar[weakref.WeakValueDictionary[uuid.UUID, _EventStore]] = (
weakref.WeakValueDictionary()
)
def __init__(
self,
events: list[Event],
timeline_index: dict[str, list[_EventBucket]],
) -> None:
self._id = uuid.uuid4()
self._events = events # flat list for trigger lookup (positional)
self._timeline_index = timeline_index
_EventStore._REGISTRY[self._id] = self
def __setstate__(self, state: dict[str, tp.Any]) -> None:
self.__dict__.update(state)
_EventStore._REGISTRY[self._id] = self
# --- construction -------------------------------------------------------
@classmethod
def from_dataframe(cls, df: pd.DataFrame) -> _EventStore:
"""Build a store from a validated events DataFrame."""
from neuralset.events import utils as ev_utils
events = ev_utils.extract_events(df)
by_timeline: dict[str, list[Event]] = collections.defaultdict(list)
for event in events:
by_timeline[event.timeline].append(event)
timeline_index: dict[str, list[_EventBucket]] = {}
for tl, tl_events in by_timeline.items():
timeline_index[tl] = _EventBucket.from_events(tl_events)
return cls(events, timeline_index)
# --- public lookup ------------------------------------------------------
def overlapping(self, start: float, duration: float, timeline: str) -> list[Event]:
"""Events overlapping [start, start + duration) on *timeline*."""
buckets = self._timeline_index.get(timeline)
if buckets is None:
return []
stop = start + duration
result: list[Event] = []
for bucket in buckets:
result.extend(bucket.overlapping(start, stop))
return result
def __getitem__(self, pos: int) -> Event:
return self._events[pos]
@property
def timelines(self) -> list[str]:
return list(self._timeline_index)
# ---------------------------------------------------------------------------
# Segment
# ---------------------------------------------------------------------------
@dataclasses.dataclass
class Segment:
    """A time window on a single :term:`timeline`, backed by a shared event store.
    Created by :func:`list_segments` (or via
    :class:`~neuralset.dataloader.Segmenter`) — not meant to be
    instantiated directly. Each segment references an internal event
    store that holds all events; overlapping events are resolved on
    access via :attr:`ns_events`.
    Segments are **mutable**: changing :attr:`start` or :attr:`duration`
    changes which events :attr:`ns_events` returns (useful for jittering).
    The underlying event data is snapshotted at creation time — later
    modifications to the source DataFrame have no effect.
    """
    start: float
    duration: float
    timeline: str
    # position in the _EventStore flat event list (for trigger lookup)
    _trigger_idx: int | None = None
    # UUID for cross-process store lookup (pickled)
    _store_id: uuid.UUID = dataclasses.field(repr=False, default_factory=uuid.uuid4)
    # strong in-process ref, not pickled
    _store_ref: _EventStore | None = dataclasses.field(default=None, repr=False)
    # memoized ns_events: invalidated when (start, duration, timeline) changes
    _cache_key: tuple[float, float, str] | None = dataclasses.field(
        default=None,
        repr=False,
        compare=False,
    )
    _cached_events: list[Event] = dataclasses.field(
        default_factory=list,
        repr=False,
        compare=False,
    )
    # Segments don't own the store — SegmentDataset holds the strong
    # references that keep stores alive across pickle (see _event_stores).
    @property
    def _store(self) -> _EventStore:
        """Resolve the backing store: strong in-process ref, else UUID registry.
        Raises
        ------
        RuntimeError
            If no store with :attr:`_store_id` exists in this process (the
            segment was unpickled without its ``SegmentDataset``).
        """
        if self._store_ref is not None:
            return self._store_ref
        try:
            store = _EventStore._REGISTRY[self._store_id]
        except KeyError:
            raise RuntimeError(
                "_EventStore not found in this process. This segment was "
                "likely unpickled outside of a SegmentDataset/DataLoader."
            ) from None
        self._store_ref = store  # keep strong ref from now on
        return store
    @property
    def stop(self) -> float:
        """End time of the segment.
        Returns
        -------
        float
            The stop time, calculated as ``start + duration``.
        """
        return self.start + self.duration
    @property
    def trigger(self) -> Event | None:
        """The event that triggered this segment, or ``None``."""
        if self._trigger_idx is None:
            return None
        return self._store[self._trigger_idx]
    @property
    def ns_events(self) -> list[Event]:
        """Events overlapping this segment's ``[start, start + duration)`` window.
        The result is cached and automatically invalidated when
        :attr:`start`, :attr:`duration`, or :attr:`timeline` changes.
        """
        key = (self.start, self.duration, self.timeline)
        if self._cache_key != key:
            self._cached_events = self._store.overlapping(
                self.start, self.duration, self.timeline
            )
            self._cache_key = key
        return self._cached_events
    @property
    def events(self) -> pd.DataFrame:
        """Events occurring within the segment.
        Returns
        -------
        pd.DataFrame
            DataFrame containing all events in this segment, with the
            original indices preserved.
        """
        ns_events = self.ns_events
        evts = [e.to_dict() for e in ns_events]
        for ev in evts:
            # drop the itertuples-style positional field; the real index is
            # restored below from each event's stored _index
            ev.pop("Index", None)
        indices = [e._index for e in ns_events]
        df = pd.DataFrame(data=evts, index=indices)
        return df.sort_index()
    def _to_extractor(self) -> dict[str, tp.Any]:
        """Dict that can be passed as ``**kwargs`` to extractors.
        Returns
        -------
        dict
            Keys: ``start``, ``duration``, ``events``, ``trigger``.
        """
        return {
            "start": self.start,
            "duration": self.duration,
            "events": self.ns_events,
            "trigger": self.trigger,
        }
    def copy(self, offset: float = 0.0, duration: float | None = None) -> Segment:
        """Create a copy of the current segment with optional offset and duration.
        Parameters
        ----------
        offset : float, optional
            Amount (in seconds) added to the copy's start time. Default is 0.0.
        duration : float, optional
            Duration for the copy; if None, the current duration is kept.
        Returns
        -------
        Segment
            A new segment sharing the same event store (fresh ns_events cache).
        """
        return Segment(
            start=self.start + offset,
            duration=duration if duration is not None else self.duration,
            timeline=self.timeline,
            _trigger_idx=self._trigger_idx,
            _store_id=self._store_id,
            _store_ref=self._store_ref,
        )
    def __getstate__(self) -> dict[str, tp.Any]:
        state = self.__dict__.copy()
        # _store_ref is large (all events); segments pickle only the
        # UUID and look up the store via _EventStore._REGISTRY.
        state["_store_ref"] = None
        return state
    def __setstate__(self, state: dict[str, tp.Any]) -> None:
        self.__dict__.update(state)
        # Re-establish strong ref if the store is already in this process.
        if self._store_ref is None:
            self._store_ref = _EventStore._REGISTRY.get(self._store_id)
# ---------------------------------------------------------------------------
# Segment creation
# ---------------------------------------------------------------------------
def list_segments(
    events: pd.DataFrame,
    triggers: pd.Series,
    *,
    start: float = 0.0,
    duration: float | None = None,
    stride: float | None = None,
    stride_drop_incomplete: bool = True,
) -> list[Segment]:
    """Build segments from trigger events, as single or strided windows.
    Parameters
    ----------
    events : pd.DataFrame
        Events table, normalized beforehand with :func:`~neuralset.events.standardize_events`.
    triggers : pd.Series
        Boolean mask selecting the reference events that define segments.
    start : float, optional
        Offset (in seconds) applied to each segment start relative to its
        reference event (negative values start earlier). Default is 0.0.
    duration : float, optional
        Segment duration in seconds; if None, each reference event's own
        duration is used. Mandatory for strided windows. Default is None.
    stride : float, optional
        Step (in seconds) between consecutive sliding windows. Default is None.
    stride_drop_incomplete : bool, optional
        With a stride, drop windows that do not fit entirely inside the
        reference event's span. Default is True.
    Returns
    -------
    list of Segment
        Segments whose :attr:`ns_events` resolve lazily against a shared
        event store, with :attr:`trigger` pointing at the reference event.
    Raises
    ------
    RuntimeError
        If ``stride`` is given without ``duration``.
    ValueError
        If the DataFrame was not standardized, or no trigger matches.
    TypeError
        If ``triggers`` is not a ``pd.Series``.
    Notes
    -----
    Two modes:
    1. Single window: one segment per trigger event.
    2. Sliding window: per trigger event, windows of ``duration`` stepped
       by ``stride`` across the event's span.
    Examples
    --------
    Single window segmentation:
    >>> seg = list_segments(events, triggers=events.type == "Image", start=-0.5, duration=2.0)
    Sliding window segmentation:
    >>> seg = list_segments(events, triggers=events.type == "Meg", stride=1, duration=3)
    """
    # Normalize scalar parameters up front (numpy scalars -> Python floats).
    start = float(start)
    if duration is not None:
        if isinstance(duration, np.ndarray):
            duration = duration.item()
        duration = float(duration)
    if stride is not None:
        stride = float(stride)
        if duration is None:
            raise RuntimeError("duration must be provided for strided windows")
    if not hasattr(events, "stop"):
        raise ValueError("Run standardize_events on the DataFrame first")
    if not isinstance(triggers, pd.Series):
        raise TypeError(
            f"triggers must be a boolean pd.Series, got {type(triggers).__name__}"
        )
    store = _EventStore.from_dataframe(events)
    trigger_df = events.loc[triggers]
    if not len(trigger_df):
        raise ValueError("Empty trigger events")
    # positional indices of triggers into the store's flat event list
    positions = events.index.get_indexer(trigger_df.index)
    starts_list: list[tp.Any] = []
    durations_list: list[tp.Any] = []
    pos_list: list[tp.Any] = []
    tl_list: list[tp.Any] = []
    if stride is None:
        # Mode 1: one window per trigger event.
        starts_list = (trigger_df["start"] + start).tolist()
        if duration is None:
            durations_list = trigger_df["duration"].tolist()
        else:
            durations_list = [duration] * len(trigger_df)
        pos_list = positions.tolist()
        tl_list = trigger_df["timeline"].tolist()
    else:
        # Mode 2: strided windows across each trigger event's span.
        assert duration is not None
        for pos, trig in zip(positions, trigger_df.itertuples()):
            win_starts, win_durs = _prepare_strided_windows(
                float(trig.start) + start,  # type: ignore
                float(trig.stop),  # type: ignore
                stride,
                duration,
                drop_incomplete=stride_drop_incomplete,
            )
            starts_list.extend(win_starts)
            durations_list.extend(win_durs)
            pos_list.extend([pos] * len(win_starts))
            tl_list.extend([str(trig.timeline)] * len(win_starts))
    return [
        Segment(
            start=s,
            duration=d,
            timeline=tl,
            _trigger_idx=trig_idx,
            _store_id=store._id,
            _store_ref=store,
        )
        for s, d, trig_idx, tl in zip(starts_list, durations_list, pos_list, tl_list)
    ]
# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------
def _prepare_strided_windows(
start: float,
stop: float,
stride: float,
duration: float,
drop_incomplete: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
"""Prepare parameters for strided sliding windows.
Parameters
----------
start : float
Start time of the overall window.
stop : float
Stop time of the overall window.
stride : float
Step size between consecutive windows.
duration : float
Duration of each window.
drop_incomplete : bool, optional
If True, drop windows that are not fully contained within (start, stop).
This matches the behavior of :code:`mne.events_from_annotations` with
:code:`chunk_duration`. Default is True.
Returns
-------
starts : np.ndarray
Array of start times for each window.
durations : np.ndarray
Array of durations for each window (all equal to the input duration).
Notes
-----
When :code:`drop_incomplete=True`, the effective stop time is adjusted to
:code:`stop - duration` to ensure the last window fits completely.
"""
effective_stop = (stop - duration) if drop_incomplete else stop
span = effective_stop - start
if span < -1e-10:
msg = f"Dropping all windows as asking window duration {duration} on "
msg += f"shorter overall window duration {stop - start} and {drop_incomplete=}"
raise RuntimeError(msg)
ratio = span / stride
if drop_incomplete:
n = math.floor(ratio + 1e-9) + 1
else:
n = math.ceil(ratio - 1e-9)
n = max(0, n)
starts = start + np.arange(n) * stride
durations = np.full_like(starts, fill_value=duration)
return starts, durations
def find_enclosed(df: pd.DataFrame, start: float, duration: float) -> pd.Series:
    """Find events fully enclosed within a time window.
    Parameters
    ----------
    df : pd.DataFrame
        DataFrame of events with 'start' and 'duration' columns.
    start : float
        Start time of the enclosing window.
    duration : float
        Duration of the enclosing window.
    Returns
    -------
    pd.Series
        Indices of the events fully enclosed in the window.
    Notes
    -----
    An event is enclosed when both its endpoints fall inside the window:
    :code:`start <= event_start` and :code:`event_end <= start + duration`.
    Examples
    --------
    >>> enclosed = find_enclosed(events, start=5.0, duration=10.0)
    >>> events.loc[enclosed]  # Get fully enclosed events
    """
    begins = np.array(df.start)
    ends = begins + np.array(df.duration)
    inside = (begins >= start) & (ends <= start + duration)
    return pd.Series(df.index[inside])
def find_overlap(
    events: pd.DataFrame,
    triggers: pd.Series | None = None,
    *,
    start: float = 0.0,
    duration: float | None = None,
) -> pd.Series:
    """Find events overlapping a reference window or reference events.
    Parameters
    ----------
    events : pd.DataFrame
        DataFrame containing events.
    triggers : pd.Series, optional
        Reference events to test overlap against. When None, the absolute
        window given by ``start``/``duration`` is used. Default is None.
    start : float, optional
        Absolute window start when ``triggers`` is None, otherwise offset
        relative to each reference event. Default is 0.0.
    duration : float, optional
        Reference window duration. Required when ``triggers`` is None;
        otherwise None falls back to each reference event's own duration.
    Returns
    -------
    pd.Series
        Indices of events overlapping the reference.
    Raises
    ------
    AssertionError
        If ``triggers`` is None and ``duration`` is missing, or the events
        span multiple timelines in absolute-window mode.
    Notes
    -----
    An event overlaps when it starts inside the reference window, ends
    inside it, or fully encloses it.
    Examples
    --------
    Find overlap with absolute time window:
    >>> overlapping = find_overlap(events, start=5.0, duration=10.0)
    Find overlap with reference events:
    >>> ref_events = events.type == "Stimulus"
    >>> overlapping = find_overlap(events, triggers=ref_events, start=-1.0, duration=3.0)
    """
    if triggers is not None:
        # Delegate to segment machinery: one segment per trigger, then
        # collect the original indices of all resolved events.
        hits: list[int] = []
        for segment in list_segments(
            events,
            triggers=triggers,
            start=start,
            duration=duration,
        ):
            hits.extend(e._index for e in segment.ns_events)  # type: ignore[misc]
        return pd.Series(hits)
    assert duration is not None
    assert events.timeline.nunique() == 1
    stop = start + duration
    ev_start = events.start
    ev_stop = events.start + events.duration
    begins_inside = (ev_start >= start) & (ev_start < stop)
    ends_inside = (ev_stop > start) & (ev_stop <= stop)
    encloses = (ev_start <= start) & (ev_stop >= stop)
    mask = begins_inside | ends_inside | encloses
    return pd.Series(events.index[mask])
def find_incomplete_segments(
    segments: tp.Sequence[Segment], event_types: tp.Sequence[tp.Type[Event]]
) -> list[int]:
    """Find segments missing required event types.
    A segment is incomplete when, for at least one required type, it holds
    no event of that type or of any of its subclasses.
    Parameters
    ----------
    segments : sequence of Segment
        Segments to check for completeness.
    event_types : sequence of Event classes
        Event types that must each appear in every segment.
    Returns
    -------
    list of int
        Sorted indices of segments missing at least one required type.
    Warnings
    --------
    Logs one warning per event type stating how many segments lack it.
    Examples
    --------
    >>> from neuralset.events.etypes import Image, Audio
    >>> incomplete = find_incomplete_segments(segments, [Image, Audio])
    >>> complete_segments = [s for i, s in enumerate(segments) if i not in incomplete]
    Notes
    -----
    Subclasses count: checking for :class:`~neuralset.events.Image` accepts
    segments containing :class:`~neuralset.events.NaturalImage` as well.
    """
    missing: set[int] = set()
    for event_type in event_types:
        # names of all registered classes satisfying this required type
        accepted = {
            name for name, klass in Event._CLASSES.items() if issubclass(klass, event_type)
        }
        bad = {
            idx
            for idx, segment in enumerate(segments)
            if all(e.type not in accepted for e in segment.ns_events)
        }
        if bad:
            logger.warning(
                f"{len(bad)} segments out of {len(segments)} did not contain valid events for event type {event_type}"
            )
        missing |= bad
    return sorted(missing)