# Source code for neuralset.events.transforms.chunking

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import numpy as np
import pandas as pd

from .. import etypes as ev
from ..study import EventsTransform

Tiling = tp.Literal["equal", "max"]


class _Section(tp.NamedTuple):
    """Contiguous ``[start, stop)`` time span, label-homogeneous by construction.

    Sections are produced already aligned with ``split`` runs, so every
    sub-section returned by :meth:`tile` carries a single label.
    """

    start: float
    stop: float

    @property
    def duration(self) -> float:
        """Span length, ``stop - start``."""
        return self.stop - self.start

    def tile(
        self, min_duration: float, max_duration: float, tiling: Tiling
    ) -> list["_Section"]:
        """Partition the section into contiguous sub-sections.

        Parameters
        ----------
        min_duration : float
            Lower bound on sub-section duration. Only enforced by ``"equal"``;
            ``"max"`` leaves short trailing pieces to the caller's post-filter.
        max_duration : float
            Upper bound on sub-section duration. Must be ``> 0`` (may be
            ``inf``).
        tiling : {"equal", "max"}
            ``"equal"``: the fewest maximally equal-sized chunks, each in
            ``[min_duration, max_duration]`` (up to the sample-grid tolerance).
            ``"max"``: chunks of exactly ``max_duration`` plus a trailing
            partial chunk covering any remainder longer than the tolerance.

        Returns
        -------
        list of _Section
            Sub-sections in increasing time order. ``"equal"`` always covers
            ``[start, stop)`` exactly; ``"max"`` may return an empty list for
            a section whose duration is within tolerance of zero.

        Raises
        ------
        ValueError
            If ``tiling`` is unknown or ``max_duration`` is not ``> 0``; for
            ``"equal"``, if ``2 * min_duration > max_duration`` (bounds
            unsatisfiable) or the section is shorter than ``min_duration``.
        """
        tol = ChunkEvents._SAMPLE_GRID_TOL
        if tiling not in ("equal", "max"):
            raise ValueError(f"unknown tiling {tiling!r}; expected 'equal' or 'max'")
        # ``not x > 0`` also rejects NaN, which would otherwise slip through.
        if not max_duration > 0:
            raise ValueError(f"max_duration={max_duration} must be > 0")
        if tiling == "equal":
            if 2 * min_duration > max_duration + tol:
                raise ValueError(
                    f"min_duration={min_duration} must be <= "
                    f"max_duration/2={max_duration / 2} so uniform tiling "
                    f"always yields chunks in [min_duration, max_duration]"
                )
            if self.duration < min_duration - tol:
                raise ValueError(
                    f"section [{self.start:.3f}, {self.stop:.3f}) (duration "
                    f"{self.duration:.3f}s) cannot tile into "
                    f"[{min_duration}, {max_duration}]. Set ``min_duration`` "
                    f"below the shortest section, merge short same-split runs "
                    f"in ``event_type_to_split_by`` before chunking, or use "
                    f'``tiling="max"`` to drop short trailing pieces instead.'
                )
            # ``tol`` is in seconds: remove it from the duration *before*
            # dividing, so the slack is a fixed time quantum instead of
            # scaling with ``max_duration`` (which let a chunk overshoot
            # ``max_duration`` by up to ``tol * max_duration``).
            n = max(1, int(np.ceil((self.duration - tol) / max_duration)))
            step = self.duration / n
            # Pin the last edge to ``self.stop`` to avoid accumulated float error.
            edges = [self.start + i * step for i in range(n)] + [self.stop]
            return [_Section(a, b) for a, b in zip(edges[:-1], edges[1:])]
        # tiling == "max"
        if not np.isfinite(max_duration):
            return [self]  # post-filter handles ``duration < min_duration``
        # Same unit handling as above: pad the duration by ``tol`` seconds so a
        # duration a hair below a multiple of ``max_duration`` still counts.
        n_full = max(0, int((self.duration + tol) / max_duration))
        edges = [self.start + i * max_duration for i in range(n_full + 1)]
        out = [_Section(a, b) for a, b in zip(edges[:-1], edges[1:])]
        if self.duration - n_full * max_duration > tol:
            out.append(_Section(edges[-1], self.stop))
        return out


def _build_sections(
    event: ev.BaseSplittableEvent,
    use_rows: pd.DataFrame | None = None,
    allow_sample_leakage: bool = False,
) -> list[_Section]:
    """Partition ``event`` into label-homogeneous, contiguous sections.

    Only rows of ``use_rows`` overlapping the event (by more than the
    sample-grid tolerance) are considered, in order of ``start``. Wherever two
    consecutive rows carry a different ``split``, a boundary is placed at the
    midpoint between the latest stop of all earlier rows and the next row's
    start. The first section absorbs any leading silence and starts at
    ``event.start``; the last extends to ``event.start + event.duration``.
    Boundaries are clipped to the event span and zero-length sections are
    dropped.

    Parameters
    ----------
    event : BaseSplittableEvent
        Event to partition; its ``start``, ``duration`` and ``frequency`` are
        read.
    use_rows : DataFrame, optional
        Label-carrying rows with ``start``, ``duration`` and ``split`` columns.
        When None/empty, or when no row overlaps ``event``, a single section
        covering the whole event is returned.
    allow_sample_leakage : bool
        If True, skip the minimum-silence check described under *Raises*.

    Returns
    -------
    list of _Section
        Contiguous sections in increasing time order.

    Raises
    ------
    ValueError
        If a split transition leaves less than ``1 / frequency`` of silence
        (unless ``allow_sample_leakage=True``): ``event._split``'s
        nearest-grid snap of the midpoint isn't guaranteed to stay in the gap
        below that, so a labeled sample could end up in the wrong chunk.
        Overlapping rows of different splits count as a negative gap.
    """
    tol = ChunkEvents._SAMPLE_GRID_TOL
    e_start, e_stop = event.start, event.start + event.duration
    whole = [_Section(e_start, e_stop)]
    if use_rows is None or len(use_rows) == 0:
        return whole
    row_stops = use_rows["start"] + use_rows["duration"]
    overlaps = (row_stops > e_start + tol) & (use_rows["start"] < e_stop - tol)
    # Stable sort so rows sharing a start keep their input order (deterministic).
    rows = use_rows[overlaps].sort_values("start", kind="stable")
    if rows.empty:
        return whole
    freq = float(event.frequency)
    min_gap = 1.0 / freq if freq > 0 else 0.0
    starts = rows["start"].to_numpy(dtype=float)
    stops = starts + rows["duration"].to_numpy(dtype=float)
    # Latest stop among all rows seen so far: a long earlier row can outlast
    # shorter rows starting after it, so the immediately preceding row's stop
    # alone could hide an overlap with the next split run.
    latest_stop = np.maximum.accumulate(stops)
    splits = rows["split"].astype(str).to_numpy()
    boundaries = [e_start]
    # Indices i where rows i-1 and i carry different split labels.
    for i in np.flatnonzero(splits[1:] != splits[:-1]) + 1:
        prev_stop, next_start = latest_stop[i - 1], starts[i]
        if not allow_sample_leakage and next_start - prev_stop < min_gap - tol:
            raise ValueError(
                f"label leakage: split transition at {prev_stop:.3f}s "
                f"({splits[i - 1]!r} -> {splits[i]!r}) must leave "
                f"at least 1/frequency = {min_gap:.3f}s of silence at {freq} Hz. "
                f"Pass ``allow_sample_leakage=True`` to accept up to 1 sample "
                f"of mislabeling at the boundary."
            )
        # Clipping keeps sections inside the event even when overlapping
        # labels (tolerated via ``allow_sample_leakage``) push the midpoint out.
        midpoint = (prev_stop + next_start) / 2
        boundaries.append(float(np.clip(midpoint, e_start, e_stop)))
    boundaries.append(e_stop)
    sections = [
        _Section(a, b) for a, b in zip(boundaries[:-1], boundaries[1:]) if b > a
    ]
    # Only a zero-duration event can leave nothing; keep the original fallback.
    return sections or whole


class ChunkEvents(EventsTransform):
    """Chunk long events into shorter events.

    Typical use: keep long recordings under a deep-learning model's memory
    budget (e.g. Wav2Vec).

    Parameters
    ----------
    event_type_to_chunk : str
        Splittable event type to chunk. Any
        :class:`~neuralset.events.etypes.BaseSplittableEvent` subclass
        (Audio, Video, Meg, Eeg, Fmri, ...).
    max_duration : float, default=``np.inf``
        Upper bound on chunk duration in seconds. Must be ``> 0``.
    min_duration : float, default=0.0
        Lower bound on chunk duration. Behavior when impossible depends on
        ``tiling`` (see below).
    tiling : {"max", "equal"}, default ``"max"``
        How each section is sub-divided:

        - ``"max"``: emit chunks of exactly ``max_duration`` until the section
          is exhausted; the trailing partial chunk is dropped iff its duration
          is ``< min_duration``.
        - ``"equal"``: equal-sized chunks, each in
          ``[min_duration, max_duration]``. Requires
          ``2 * min_duration <= max_duration``. Raises if a section is shorter
          than ``min_duration``.
    event_type_to_split_by : str, optional
        Align chunk boundaries with train/val/test labels carried by another
        event type's ``split`` column, to avoid label leakage at split
        transitions. When set, chunk boundaries follow same-``split`` runs and
        each run is sub-tiled per ``tiling``.
    allow_sample_leakage : bool, default=False
        Only relevant when ``event_type_to_split_by`` is set. If True, accept
        up to 1 sample of mislabeling at split transitions with sub-sample
        silence gaps (e.g. coarse-TR Fmri); otherwise raise.

    Invariants
    ----------
    - Every emitted chunk has ``duration >= min_duration``.
    - Every emitted chunk is label-homogeneous when ``event_type_to_split_by``
      is set.
    - ``"equal"`` is lossless: concatenation reconstructs the input event.
    - ``"max"`` may silently drop sections/trailing pieces shorter than
      ``min_duration``.

    Raises
    ------
    ValueError
        - At construction: ``event_type_to_chunk`` is not a splittable event
          type, ``max_duration`` is not ``> 0``, or ``tiling="equal"`` with
          ``2 * min_duration > max_duration``.
        - ``tiling="equal"`` and a same-``split`` run shorter than
          ``min_duration`` (cannot tile without losing labeled data — switch to
          ``tiling="max"`` to drop short pieces instead).
        - Two consecutive differently-labeled runs are less than one sample
          apart (cannot separate without label leakage), unless
          ``allow_sample_leakage=True``.
        - An event's ``frequency`` is so high that the sample-grid tolerance
          reaches one sample period.
    RuntimeError
        ``event_type_to_split_by`` is set but the events have no ``split``
        column.

    Examples
    --------
    Simple chunking (each ``x`` = one sample; sound sampled at 1 Hz)::

        input:
          max_duration: 4
          events:
            sound: [x x x x x x x x x x x x x]  # 13 s

        out (tiling="max"):  # tile with max duration + trail
          events:
            sound1: [x x x x]
            sound2: [x x x x]
            sound3: [x x x x]
            sound4: [x]  # short trailing chunk

        out (tiling="equal"):  # tile with ~ same length
          events:
            sound1: [x x x]
            sound2: [x x x]
            sound3: [x x x x]  # 3.25 s ideal, rounded to whole sample
            sound4: [x x x]

    With train/test split labels::

        input:
          max_duration: 4
          event_type_to_split_by: Word
          events:
            sound: [x x x x x x x x x x x x x]  # 13 s
            word:   1 1 1 - - 2 2 2 2 2 2 2 2   # 1=test, 2=train, -=silence

        out (tiling="equal"):  # split-aligned, then sub-tiled
          events:
            sound1: [x x x x]  # test run
            sound2: [x x x]    # train run, 3 equal chunks
            sound3: [x x x]
            sound4: [x x x]
    """

    event_type_to_chunk: str
    event_type_to_split_by: str | None = None
    min_duration: float = 0.0
    max_duration: float = np.inf
    tiling: Tiling = "max"
    allow_sample_leakage: bool = False

    # Slack for sample-aligned float compares; absorbs float noise without
    # masking real mismatches so long as ``event.frequency << 1 / tol``
    # (i.e. tol << one sample period). Enforced per-event in ``_chunk_timeline``.
    _SAMPLE_GRID_TOL: tp.ClassVar[float] = 1e-6

    def model_post_init(self, log__: object) -> None:
        """Validate the configuration once, at construction time."""
        super().model_post_init(log__)
        cls = ev.Event._CLASSES.get(self.event_type_to_chunk)
        if cls is None or not issubclass(cls, ev.BaseSplittableEvent):
            splittable = [
                n
                for n, c in ev.Event._CLASSES.items()
                if issubclass(c, ev.BaseSplittableEvent)
            ]
            raise ValueError(
                f"{self.event_type_to_chunk!r} is not a splittable event type. "
                f"Use one of {splittable}"
            )
        # Fail early instead of a ZeroDivisionError/OverflowError while tiling;
        # ``not x > 0`` also rejects NaN.
        if not self.max_duration > 0:
            raise ValueError(f"max_duration={self.max_duration} must be > 0")
        if (
            self.tiling == "equal"
            and 2 * self.min_duration > self.max_duration + self._SAMPLE_GRID_TOL
        ):
            raise ValueError(
                f"min_duration={self.min_duration} must be <= "
                f"max_duration/2={self.max_duration / 2} for tiling='equal' "
                f"so chunks always land in [min_duration, max_duration]"
            )

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        """Chunk target events timeline by timeline and return a fresh index."""
        if self.event_type_to_split_by is not None and "split" not in events.columns:
            raise RuntimeError("Events must have a split column")
        chunked = [self._chunk_timeline(df) for _, df in events.groupby("timeline")]
        if not chunked:
            # ``pd.concat([])`` raises; with no timeline groups there is
            # nothing to chunk.
            return events.reset_index(drop=True)
        return pd.concat(chunked).reset_index(drop=True)

    def _chunk_timeline(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return one timeline's events with target rows replaced by their chunks.

        Untouched rows come first (original order), followed by the chunks.
        """
        to_chunk = df["type"] == self.event_type_to_chunk
        if not to_chunk.any():
            return df
        use_rows = (
            df.loc[df["type"] == self.event_type_to_split_by]
            if self.event_type_to_split_by is not None
            else None
        )
        tol = self._SAMPLE_GRID_TOL
        added: list[dict] = []
        for row in df.loc[to_chunk].itertuples(index=False):
            event = ev.BaseSplittableEvent.from_dict(row)
            if event.frequency * tol >= 1:
                raise ValueError(
                    f"event frequency {event.frequency} Hz is too high for the "
                    f"sample-grid tolerance {tol}s "
                    f"(>= 1 sample period); boundary compares would mask real "
                    f"1-sample mismatches."
                )
            sections = _build_sections(event, use_rows, self.allow_sample_leakage)
            tiled = [
                sub
                for s in sections
                for sub in s.tile(self.min_duration, self.max_duration, self.tiling)
            ]
            # Split points are the starts of every sub-section but the first,
            # relative to the event start.
            rel_tps = [t.start - event.start for t in tiled[1:]]
            pieces = event._split(rel_tps)
            added.extend(
                p.to_dict() for p in pieces if p.duration >= self.min_duration - tol
            )
        # Boolean mask rather than ``df.drop(labels)``: with a non-unique index,
        # label-based dropping would also remove non-target rows sharing labels.
        return pd.concat([df.loc[~to_chunk], pd.DataFrame(added)])