Source code for neuralset.events.transforms.utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import bisect
import collections
import hashlib
import logging
import random
import re
import typing as tp
import unicodedata
from dataclasses import dataclass
from functools import lru_cache

import numpy as np
import pandas as pd

from neuralset import segments as _segs
from neuralset import utils as _ns_utils

from .. import etypes as ev

logger = logging.getLogger(__name__)


# Placeholder ``text`` that ``_extract_sentences`` gives to Sentence events
# whose words carry no non-empty string ``sentence`` annotation.
MISSING_SENTENCE = "# MISSING SENTENCE #"


# ---------------------------------------------------------------------------
# text helpers
# ---------------------------------------------------------------------------


@lru_cache
def parse_text(text: str, language: str = "") -> tp.Any:
    """Run the spaCy model obtained for *language* over *text*.

    Memoized with ``functools.lru_cache`` (default 128 entries): repeated calls
    with identical arguments return the very same cached object.
    """
    return _ns_utils.get_spacy_model(language=language)(text)


def _extract_sentences(events) -> list[ev.Sentence]:
    """Extract sentence events from the words with sentence annotations.

    Word events (all types listed by ``ev.EventTypesHelper("Word")``) are read
    in dataframe order and grouped into runs of consecutive words. A new run
    starts whenever, compared to the previous word:

    - the timeline changes,
    - the ``sentence`` text changes, or
    - ``sentence_char`` does not increase (the same sentence text repeated).

    Each run becomes one ``ev.Sentence`` spanning from the first word's start
    to the last word's end, padded by a small epsilon on both sides. Runs
    whose ``sentence`` is not a non-empty string get ``MISSING_SENTENCE``.

    Raises
    ------
    ValueError
        If two consecutive words on the same timeline are not sorted by start.
    """
    wtypes = ev.EventTypesHelper("Word")
    words_df = events.loc[events.type.isin(wtypes.names), :]
    eps = 1e-6

    def _to_sentence(group: list[tp.Any]) -> ev.Sentence:
        # one Sentence covering every word of the group, padded by eps
        first, last = group[0], group[-1]
        text = first.sentence
        if not (isinstance(text, str) and text):
            text = MISSING_SENTENCE
        return ev.Sentence(
            start=first.start - eps,
            duration=last.start + last.duration - first.start + 2 * eps,
            timeline=first.timeline,
            text=text,
        )

    sentences: list[ev.Sentence] = []
    group: list[tp.Any] = []
    for word in words_df.itertuples(index=False):
        if group:
            prev = group[-1]
            if prev.timeline == word.timeline and word.start < prev.start:
                raise ValueError(
                    f"Words are not sorted within a timeline ({group!r} and then {word!r})"
                )
            is_boundary = (
                prev.timeline != word.timeline
                or word.sentence != prev.sentence
                or word.sentence_char <= prev.sentence_char
            )
            if is_boundary:
                # the boundary is checked before appending, so the final word
                # can also open its own sentence
                sentences.append(_to_sentence(group))
                group = []
        group.append(word)
    if group:
        sentences.append(_to_sentence(group))
    return sentences


_LIGATURES = str.maketrans({"œ": "oe", "Œ": "OE", "æ": "ae", "Æ": "AE"})


def _normalize_with_positions(text: str) -> tuple[str, list[int]]:
    """Normalize *text* (lowercase, decompose ligatures, strip accents)
    and return a mapping from each normalized character index to its
    original position in *text*.
    """
    out: list[str] = []
    orig: list[int] = []
    for i, ch in enumerate(text.translate(_LIGATURES)):
        for c in unicodedata.normalize("NFKD", ch):
            if not unicodedata.combining(c):
                out.append(c)
                orig.append(i)
    return "".join(out), orig


class TextWordMatcher:
    """Match annotated words to character positions in a spaCy-parsed text.

    Two-phase strategy:
    1. Token-level Levenshtein alignment (spaCy tokens vs input words).
    2. Character-level fallback for unmatched words, using the raw text
       spans between surrounding matched tokens.

    After matching, each word gets ``sentence``, ``sentence_char``, and
    ``text_char`` when a reliable match is found.  Words sandwiched
    between neighbours in the same sentence inherit the sentence label.
    """

    # strips leading/trailing runs of non-word characters and underscores
    _PUNCT_RE = re.compile(r"^[\W_]+|[\W_]+$", re.UNICODE)

    def __init__(self, text: str, language: str = "") -> None:
        self.doc = parse_text(text, language=language)
        self.text = text
        # flat token list, in sentence order
        self.tokens: list[tp.Any] = [tok for sent in self.doc.sents for tok in sent]

    @staticmethod
    def normalize(word: str) -> str:
        """Lowercase, strip accents/ligatures, and strip leading/trailing punctuation."""
        text, _ = _normalize_with_positions(word.lower())
        return TextWordMatcher._PUNCT_RE.sub("", text)

    def match(self, words: tp.Sequence[str]) -> list[dict[str, tp.Any]]:
        """Locate each of *words* in the text.

        Returns one dict per input word, in order. A matched word gets
        ``text_char`` (offset in the text), ``sentence`` (text of its sentence
        including trailing whitespace) and ``sentence_char`` (offset within
        that sentence). An unmatched word may only hold an inherited
        ``sentence``, or nothing at all.
        """
        token_strs = [self.normalize(t.text) for t in self.tokens]
        tok_matched, word_matched = _ns_utils.match_list(
            token_strs, [self.normalize(w) for w in words]
        )
        info: list[dict[str, tp.Any]] = [{"_word": w} for w in words]
        for ti, wi in zip(tok_matched, word_matched):
            info[wi]["_tok"] = ti

        self._resolve_gaps(info)
        self._finalize(info)
        return info

    # -- char-level fallback ------------------------------------------------

    def _resolve_gaps(self, info: list[dict[str, tp.Any]]) -> None:
        """Find contiguous unmatched runs and resolve each via char-level matching.

        Each gap is recorded as (start, stop, left_tok, right_tok), where the
        token indices are those of the surrounding matched words (``None`` at
        the edges of *info*).
        """
        gaps: list[tuple[int, int, int | None, int | None]] = []
        prev_tok: int | None = None
        gap_start: int | None = None
        for k, i in enumerate(info):
            if "_tok" in i:
                if gap_start is not None:
                    gaps.append((gap_start, k, prev_tok, i["_tok"]))
                    gap_start = None
                prev_tok = i["_tok"]
            elif gap_start is None:
                gap_start = k
        if gap_start is not None:
            gaps.append((gap_start, len(info), prev_tok, None))

        for gs, ge, left_tok, right_tok in gaps:
            self._resolve_one_gap(info[gs:ge], left_tok, right_tok)

    def _resolve_one_gap(
        self,
        gap: list[dict[str, tp.Any]],
        left_tok: int | None,
        right_tok: int | None,
    ) -> None:
        """Character-level Levenshtein fallback for one gap of unmatched words.

        The gap's normalized words are joined with spaces and aligned, char by
        char, against the raw text lying between the surrounding matched tokens
        (``None`` means the start/end of the text). Every aligned character
        votes for a start offset of its word. The most voted offset is kept
        when its vote count exceeds half the word's normalized length.
        """
        char_start = 0
        if left_tok is not None:
            left = self.tokens[left_tok]
            char_start = left.idx + len(left)
        char_end = len(self.text) if right_tok is None else self.tokens[right_tok].idx

        # Normalize *before* lowercasing: ``str.lower`` can change the length
        # of the raw text (e.g. "İ" -> "i̇"), which would shift ``orig_pos``
        # away from indices in ``self.text``.
        subtext, orig_pos = _normalize_with_positions(self.text[char_start:char_end])
        subtext = subtext.lower()
        norm_words = [self.normalize(w["_word"]) for w in gap]
        concat = " ".join(norm_words)
        if not subtext or not concat:
            return
        sub_match, concat_match = _ns_utils.match_list(subtext, concat)

        norm_lens = [len(n) for n in norm_words]
        # concat index -> (word index, char index in word); the extra slot per
        # word stands for the separating space
        char_to_word = [
            (wi, ci) for wi, nlen in enumerate(norm_lens) for ci in range(nlen + 1)
        ]
        for ti, ci in zip(sub_match, concat_match):
            wi, charnum = char_to_word[ci]
            # candidate start of word ``wi`` in the original text
            gap[wi].setdefault("_votes", []).append(char_start + orig_pos[ti] - charnum)

        tok_slice = self.tokens[left_tok:right_tok]
        for w, nlen in zip(gap, norm_lens):
            votes: list[int] | None = w.pop("_votes", None)
            if not votes:
                continue
            # most_common keeps first-seen order on ties, like max(key=count)
            best, n_best = collections.Counter(votes).most_common(1)[0]
            if n_best / max(nlen, 1) <= 0.5:
                logger.warning(
                    "Ignoring unreliable matching for '%s' in '%s'",
                    w["_word"],
                    subtext,
                )
                continue
            # a vote can point before the text start when the word's leading
            # characters are absent from the text; avoid negative offsets
            best = max(best, 0)
            found = self.text[best : best + len(w["_word"])]
            if self.normalize(w["_word"]) != self.normalize(found):
                logger.warning(
                    "Approximately matched annotated %r with %r in text",
                    w["_word"],
                    found,
                )
            if not tok_slice:
                continue
            # last token starting at or before ``best`` (or the first token)
            ind = bisect.bisect_right(tok_slice, best, key=lambda t: t.idx)
            nearest = tok_slice[max(ind - 1, 0)]
            w["text_char"] = best
            w["sentence"] = nearest.sent.text_with_ws
            w["sentence_char"] = best - nearest.sent[0].idx

    # -- finalization -------------------------------------------------------

    def _finalize(self, info: list[dict[str, tp.Any]]) -> None:
        """Convert internal keys to output keys and fill sentence gaps.

        Words without a sentence that sit between two words of the same
        sentence inherit that sentence (but no character offsets).
        """
        for i in info:
            i.pop("_word")
            tok_idx = i.pop("_tok", None)
            if tok_idx is not None:
                tok = self.tokens[tok_idx]
                i["text_char"] = tok.idx
                i["sentence"] = tok.sent.text_with_ws
                i["sentence_char"] = tok.idx - tok.sent[0].idx

        prev_sent: str | None = None
        pending: list[dict[str, tp.Any]] = []
        for i in info:
            sent = i.get("sentence")
            if sent is None:
                pending.append(i)
                continue
            if prev_sent == sent:
                for p in pending:
                    p["sentence"] = sent
            pending = []
            prev_sent = sent


def _merge_sentences(
    sentences: list[ev.Sentence],
    min_duration: float | None = None,
    min_words: int | None = None,
) -> list[list[ev.Sentence]]:
    """Merge consecutive sentences into groups, never mixing timelines.

    A sentence opens a new group only once the current group is "full":

    - with ``min_duration`` set, its start lies at least ``min_duration``
      after the start of the current group;
    - with ``min_words`` set, the current group already holds at least
      ``min_words`` whitespace-separated words.

    When both are set, both conditions must hold. A sentence on a different
    timeline than the last sentence of the current group always opens a new
    group. With neither threshold set, each sentence is its own group. The
    last group may fall short of the thresholds.
    """
    groups: list[list[ev.Sentence]] = []
    # running word count of groups[-1] (only tracked when min_words is set),
    # avoiding a full recount of the group for every new sentence
    n_words = 0
    for sentence in sentences:
        start_new = True
        if groups:
            current = groups[-1]
            if min_duration is not None:
                start_new &= sentence.start - current[0].start >= min_duration
            if min_words is not None:
                start_new &= n_words >= min_words
            start_new |= current[-1].timeline != sentence.timeline
        if start_new:
            groups.append([sentence])
            n_words = 0
        else:
            groups[-1].append(sentence)
        if min_words is not None:
            n_words += len(sentence.text.split())
    return groups


# ---------------------------------------------------------------------------
# chunking helpers
# ---------------------------------------------------------------------------


def chunk_events(
    events: pd.DataFrame,
    event_type_to_chunk: tp.Literal["Audio", "Video"],
    event_type_to_use: str | None = None,
    min_duration: float | None = None,
    max_duration: float = np.inf,
) -> pd.DataFrame:
    """Split events into smaller chunks.

    If event_type_to_use is None, the events are chunked into chunks of
    max_duration. If event_type_to_use is not None, cut points are placed every
    max_duration within each run of consecutive ``event_type_to_use`` events
    sharing the same ``split`` value, so that chunks follow the train/val/test
    splits of the event_type_to_use.

    Parameters
    ----------
    events :
        Events dataframe with at least ``type`` and ``timeline`` columns (and
        ``split``, ``start``, ``duration`` when event_type_to_use is set).
    event_type_to_chunk :
        Name of the event class to split; it must define ``_split``.
    event_type_to_use :
        Optional event type whose ``split`` column drives the cut points.
    min_duration :
        Forwarded to the event's ``_split`` method.
    max_duration :
        Spacing between consecutive cut points.

    Returns
    -------
    pd.DataFrame
        Copy of *events* where each chunked event is replaced by its pieces
        (appended at the end), with the index reset.

    Raises
    ------
    ValueError
        If event_type_to_chunk is not a class registered in
        ``ev.Event._CLASSES`` with a ``_split`` method.
    RuntimeError
        If event_type_to_use is given but *events* has no ``split`` column.
    """
    ns_event_type_to_chunk = ev.Event._CLASSES.get(event_type_to_chunk)
    if ns_event_type_to_chunk is None or not hasattr(ns_event_type_to_chunk, "_split"):
        raise ValueError(f"Event type {event_type_to_chunk} is not splittable")
    if event_type_to_use is not None and "split" not in events.columns:
        raise RuntimeError("Events must have a split column")

    added_events: list[dict] = []
    dropped_rows: list[tp.Any] = []
    for _, df in events.groupby("timeline"):
        to_chunk = df.type == event_type_to_chunk
        if not to_chunk.any():
            continue
        if event_type_to_use is None:
            # chunk based on max_duration
            segments = _segs.list_segments(
                df,
                triggers=to_chunk,
                duration=max_duration,
                stride=max_duration,
                stride_drop_incomplete=False,
            )
            timepoints = [segment.start for segment in segments]
        else:
            # chunk based on train/test split of the event_type_to_use.
            # The mask is built from ``df`` itself: masking with a Series from
            # the whole ``events`` frame relies on index alignment, which
            # breaks when ``events`` has duplicate index labels.
            timepoints = []
            events_to_use = df.loc[df.type == event_type_to_use].copy()
            split = events_to_use["split"]
            split_change = split.astype(str) != split.shift(1).astype(str)
            events_to_use["section"] = np.cumsum(split_change.values)  # type: ignore
            for _, section in events_to_use.groupby("section"):
                start = section.iloc[0].start
                end = section.iloc[-1].start + section.iloc[-1].duration
                timepoints.extend(np.arange(start, end, max_duration))
        events_to_chunk = df.loc[to_chunk]
        dropped_rows.extend(events_to_chunk.index)
        for row in events_to_chunk.itertuples(index=False):
            event_to_chunk = ns_event_type_to_chunk.from_dict(row)
            new_events = event_to_chunk._split(
                [t - event_to_chunk.start for t in timepoints], min_duration
            )  # type: ignore
            for new_event in new_events:
                new_event_dict = new_event.to_dict()
                # add the columns which were removed by event.from_dict() except index
                for k, v in row._asdict().items():  # type: ignore
                    new_event_dict.setdefault(k, v)
                added_events.append(new_event_dict)

    out_events = events.drop(index=dropped_rows)
    if added_events:  # skip a no-op concat with an empty frame
        out_events = pd.concat([out_events, pd.DataFrame(added_events)])
    return out_events.reset_index(drop=True)
# ---------------------------------------------------------------------------
# splitting helpers
# ---------------------------------------------------------------------------
@dataclass
class DeterministicSplitter:
    """Hash-based splitter that assigns a deterministic train/val/test split.

    Hashes each sample's unique ID to a float in [0, 1) and maps it to a split
    name according to cumulative ``ratios``. The assignment is stable across
    runs and independent of dataset order.

    Parameters
    ----------
    ratios :
        Mapping from split name to proportion (must sum to 1).
    seed :
        Added to the hash to produce different splits.
    """

    ratios: dict[str, float]
    seed: float = 0.0

    def __post_init__(self) -> None:
        if not all(ratio > 0 for ratio in self.ratios.values()):
            raise ValueError(f"All ratios must be positive, got {self.ratios}")
        if not np.allclose(sum(self.ratios.values()), 1.0):
            raise ValueError(f"The sum of ratios must be equal to 1. Got {self.ratios}")

    def __call__(self, uid: str) -> str:
        """Return the split name assigned to *uid*."""
        hashed = int(hashlib.sha256(uid.encode()).hexdigest(), 16)
        # Keep this exact seeding formula: changing it would move existing
        # assignments between splits.
        rng = random.Random(hashed + self.seed)
        score = rng.random()
        names = list(self.ratios.keys())
        cdf = np.cumsum(list(self.ratios.values()))
        for name, boundary in zip(names, cdf):
            if score < boundary:
                return name
        # np.allclose accepts totals slightly below 1 and float accumulation can
        # leave cdf[-1] just under 1: a score in [cdf[-1], 1) belongs to the
        # last split rather than raising.
        return names[-1]