Source code for neuralset.events.transforms.text

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import typing as tp
from collections import deque
from functools import lru_cache

import numpy as np
import pandas as pd
import pydantic
import torch
from tqdm import tqdm

from neuralset import segments as _segs

from .. import etypes as ev
from ..study import EventsTransform
from .utils import MISSING_SENTENCE, TextWordMatcher, _extract_sentences, parse_text

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# EnsureTexts — canonical way to guarantee Text events exist
# ---------------------------------------------------------------------------


@lru_cache
def _get_punct_model() -> tp.Any:
    from deepmultilingualpunctuation import PunctuationModel

    return PunctuationModel()


class EnsureTexts(EventsTransform):
    """Create Text events from Words if not already present.

    Parameters
    ----------
    punctuation : str or None
        - ``"spacy"``: capitalize the first letter of each sentence and add
          language-appropriate sentence-ending punctuation via spaCy sentence
          segmentation.
        - ``"fullstop"``: ``oliverguhr/fullstop-punctuation-multilang-large``
          DL punctuation restoration (requires ``deepmultilingualpunctuation``).
          May need a GPU.
        - ``None``: plain space-join (no punctuation added).
    """

    punctuation: tp.Literal["spacy", "fullstop"] | None = "spacy"

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        text_names = ev.EventTypesHelper("Text").names
        if events.type.isin(text_names).any():
            return events
        word_names = ev.EventTypesHelper("Word").names
        words = events.loc[events.type.isin(word_names)].sort_values(
            ["timeline", "start"]
        )
        if words.empty:
            return events
        text_rows: list[dict[str, tp.Any]] = []
        for timeline, word_group in words.groupby("timeline"):
            raw = " ".join(word_group.text.values)
            text_str = self._punctuate(raw, word_group)
            start = float(word_group.start.min())
            stop = float((word_group.start + word_group.duration).max())
            text_rows.append(
                dict(
                    type="Text",
                    start=start,
                    duration=stop - start,
                    timeline=timeline,
                    text=text_str,
                    language=(
                        word_group.language.iloc[0]
                        if "language" in word_group.columns
                        else ""
                    ),
                )
            )
        return pd.concat([events, pd.DataFrame(text_rows)], ignore_index=True)

    def _punctuate(self, raw: str, words_df: pd.DataFrame) -> str:
        if self.punctuation is None:
            return raw
        if self.punctuation == "fullstop":
            return _add_punctuation(raw)
        lang = words_df.language.iloc[0] if "language" in words_df.columns else ""
        doc = parse_text(raw, lang)
        periods = {
            "chinese": "\u3002",
            "zh": "\u3002",
            "japanese": "\u3002",
            "ja": "\u3002",
        }
        period = periods.get(lang, ".")
        sep = period + " "
        parts: list[str] = []
        for sent in doc.sents:
            s = sent.text.strip()
            if s:
                s = s[0].upper() + s[1:]
                parts.append(s)
        text = sep.join(parts)
        if text and not text.endswith(period):
            text += period
        return text
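

# Illustrative usage sketch (added for documentation): the data below is
# hypothetical and ``_demo_ensure_texts`` is not part of the library API.
# With ``punctuation=None`` the Text event is a plain space-join of the words,
# so no spaCy model or punctuation model is needed.
def _demo_ensure_texts() -> None:
    words = pd.DataFrame(
        [
            dict(type="Word", start=0.0, duration=0.4, timeline=0,
                 text="hello", language="en"),
            dict(type="Word", start=0.5, duration=0.4, timeline=0,
                 text="world", language="en"),
        ]
    )
    out = EnsureTexts(punctuation=None)._run(words)
    # expected: one Text row with start=0.0, duration=0.9, text="hello world"
    print(out.loc[out.type == "Text", ["start", "duration", "text"]])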


@lru_cache
def _add_punctuation(text: str) -> str:
    """Cached DL punctuation restoration (shared across subjects with the same text)."""
    model = _get_punct_model()
    return model.restore_punctuation(text)


class AddSentenceToWords(EventsTransform):
    """Add sentence-level information to word events based on Text rows.

    This transform processes a DataFrame containing word-level (Word) and
    text-level (Text) events. For each sentence found in the Text rows, it:

    1. Creates a new Sentence row for each sentence.
    2. Assigns ``sentence`` and ``sentence_char`` annotations to Word rows,
       indicating which sentence each word belongs to and at which character
       the word starts within the sentence.

    Parameters
    ----------
    max_unmatched_ratio : float
        Maximum allowed ratio of word rows that do not match any sentence.
        An error is raised if this ratio is exceeded.
    override_sentences : bool, default=False
        Whether to replace existing Sentence rows if they are already present.
    """

    max_unmatched_ratio: float = 0.0  # raises if too many words are unmatched
    override_sentences: bool = False

    @classmethod
    def _exclude_from_cls_uid(cls) -> list[str]:
        return super()._exclude_from_cls_uid() + ["max_unmatched_ratio"]

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        if self.max_unmatched_ratio < 0 or self.max_unmatched_ratio >= 1:
            raise ValueError(f"{self.max_unmatched_ratio=}, must be in [0, 1)")

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        """Add sentence information to each word event by parsing the
        corresponding full text."""
        if "Sentence" in events.type.unique():
            if not self.override_sentences:
                msg = "Sentence already present in events dataframe"
                logger.debug(msg)
                return events
            events = events[events.type != "Sentence"]
        if "timeline" in events.columns and len(events.timeline.unique()) > 1:
            timelines = []
            desc = "Add sentence to Word based on Text"
            # process one timeline at a time
            tl_dfs = events.groupby("timeline", sort=False)
            for _, subevents in tqdm(tl_dfs, desc=desc, mininterval=10):
                timelines.append(self._run(subevents))
            return pd.concat(timelines, ignore_index=True)
        if "Word" not in events.type.unique():
            logger.info("No Word events found, skipping")
            return events
        contexts = events.loc[events.type == "Text"]
        if contexts.empty:
            msg = "No Text event in dataframe, add it in the study "
            msg += "or use the 'EnsureTexts' transform to create Text events from Words"
            raise RuntimeError(msg)
        events = events.copy(deep=True)  # avoid side effects on the input
        wtypes = ev.EventTypesHelper("Word")
        words = events[events.type.isin(wtypes.names)]
        events.loc[:, "sentence_char"] = np.nan
        events["sentence"] = ""
        sentences = []
        for context in contexts.itertuples():
            # find words enclosed in this context (requires a unique timeline)
            encl = _segs.find_enclosed(
                events,
                start=float(context.start),  # type: ignore[arg-type]
                duration=float(context.duration),  # type: ignore[arg-type]
            )
            sub = events.loc[encl]
            sel = sub[sub.type.isin(wtypes.names)].index
            if not len(sel):
                raise ValueError("No word overlapping with context")
            wordseq = words.loc[sel].text.tolist()
            lang = getattr(context, "language", None)
            if not isinstance(lang, str):
                raise ValueError(f"Need language for Text field {context}")
            context_text = getattr(context, "text", None)
            if not isinstance(context_text, str):
                raise ValueError(f"Need text for Text field {context}")
            info = pd.DataFrame(
                TextWordMatcher(context_text, language=lang).match(wordseq),
                index=sel,
            )
            events.loc[sel, info.columns] = info
            # create sentence events
            context_sentences = [s.to_dict() for s in _extract_sentences(events)]
            subject = getattr(context, "subject", None)
            if subject is not None:
                for s in context_sentences:
                    s["subject"] = subject
            sentences.extend(context_sentences)
        sentences = [s for s in sentences if s["text"] != MISSING_SENTENCE]
        sentence_df = pd.DataFrame(sentences)
        events = pd.concat([events, sentence_df], ignore_index=True)
        events = events.sort_values("start")
        events = events.reset_index(drop=True)
        words = events[events.type.isin(wtypes.names)]
        if len(words) == 0:
            return events
        unmatched = sum(not s or not isinstance(s, str) for s in words.sentence)
        ratio = unmatched / len(words)
        if ratio > self.max_unmatched_ratio:
            max_unmatched_ratio = self.max_unmatched_ratio
            cls = self.__class__.__name__
            msg = f"Ratio of unmatched words is {ratio:.4f} on {len(words)} words "
            msg += f"while {cls}.{max_unmatched_ratio=}"
            raise RuntimeError(msg)
        return events


class AddContextToWords(EventsTransform):
    """Add a context field to the events dataframe, for each word event,
    by concatenating the sentence fields.

    .. warning::
        **Unstable API**: the context representation (per-word concatenated
        strings) will be replaced with compact indices in a future release.

    Parameters
    ----------
    sentence_only : bool
        Only use the current sentence as context.
    max_context_len : None or int
        If not None, caps the context length to a given number of words
        (counted through whitespaces).
    split_field : str
        Field on which to reset contexts. If empty, the context is only reset
        for new timelines.
    """

    sentence_only: bool = True  # only use context from the current sentence
    max_context_len: int | None = None  # cut the context after this many words
    split_field: str = "split"

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        if hasattr(events, "context"):
            # make sure it is typed as str
            events["context"] = events["context"].fillna("").astype(str)
        wtypes = ev.EventTypesHelper("Word")
        words = events.loc[events.type.isin(wtypes.names), :]
        last_word: tp.Any = None
        contexts = []
        desc = "Add context to words"
        worditer: tp.Iterator[ev.Word] = words.itertuples(index=False)  # type: ignore
        sfield = self.split_field
        if sfield and sfield not in words.columns:
            raise ValueError(f"split_field {sfield!r} is not part of dataframe columns")
        for word in tqdm(worditer, total=len(words), desc=desc, mininterval=10):
            if last_word is not None:
                # check ordering as a safety measure
                same_timeline = last_word.timeline == word.timeline
                if same_timeline and (word.start < last_word.start):
                    msg = "Words are not in increasing order "
                    msg += f"({word} follows {last_word})"
                    raise ValueError(msg)
                # reset if the split or the timeline differs
                splits = [getattr(w, sfield, "") for w in (word, last_word)]
                if splits[0] != splits[1] or not same_timeline:
                    last_word = None  # restart
            # the word was not correctly matched, so do not add a context
            has_sent = isinstance(word.sentence, str) and word.sentence
            if word.sentence_char is None or np.isnan(word.sentence_char) or not has_sent:
                contexts.append("")
                continue
            # first word: restart the parts
            if last_word is None:
                past_parts: deque[str] = deque(maxlen=self.max_context_len)
                start_char = 0  # assumes splits do not occur within sentences
            # not the first word from now on
            if last_word is not None:
                non_increasing_char = word.sentence_char <= last_word.sentence_char
                if word.sentence != last_word.sentence or non_increasing_char:
                    # new sentence
                    if self.sentence_only:
                        past_parts.clear()
                    else:
                        # add the end of the previous sentence
                        past_parts[-1] += last_word.sentence[start_char:]
                    start_char = 0
                elif past_parts:
                    # same sentence:
                    # append up to the current character to the last context part
                    last_char = int(word.sentence_char)
                    past_parts[-1] += word.sentence[start_char:last_char]
                    start_char = last_char
            # append the current word (with any preceding characters) to the context
            last_char = int(word.sentence_char) + len(word.text)
            new = word.sentence[start_char:last_char]
            contexts.append("".join(past_parts) + new)
            past_parts.append(new)
            start_char = last_char
            last_word = word
        # set the new context column
        events.loc[words.index, "context"] = contexts
        return events
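

# Illustrative usage sketch (added for documentation): hypothetical data, and
# ``_demo_add_context_to_words`` is not part of the library API. The
# ``sentence`` and ``sentence_char`` columns are filled in by hand here, as
# AddSentenceToWords would produce them.
def _demo_add_context_to_words() -> None:
    sent = "Hello world."
    events = pd.DataFrame(
        [
            dict(type="Word", start=0.0, text="Hello", timeline=0,
                 split="train", sentence=sent, sentence_char=0.0),
            dict(type="Word", start=0.5, text="world", timeline=0,
                 split="train", sentence=sent, sentence_char=6.0),
        ]
    )
    out = AddContextToWords(sentence_only=True)._run(events)
    # expected contexts: "Hello", then "Hello world"
    print(out.context.tolist())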


class AddConcatenationContext(EventsTransform):
    """Add contextual information to events by concatenating previous events
    of the same type.

    .. warning::
        **Unstable API**: the context representation will be replaced with
        compact indices in a future release.

    This transform iterates over events of a specified type (default "Word")
    and creates a ``context`` column in the DataFrame. For each event, the
    context consists of the concatenated texts of all previous events in the
    same chunk, where chunks are delimited by timeline changes, split changes,
    or sentence boundaries (if ``sentence_only=True``). Optionally, the
    context length can be limited with ``max_context_len``.

    .. note::
        If an event is missing (e.g. a previous word event), it will also be
        missing from the context. Use ``AddContextToWords`` for a more careful
        handling of context, including the addition of punctuation.

    Parameters
    ----------
    event_type : str, default="Word"
        Type of event to use for building the context.
    sentence_only : bool, default=False
        If True, chunks are defined by sentence boundaries; otherwise, by
        timeline and split changes.
    max_context_len : int | None, default=None
        Maximum number of previous events to include in the context.
        If None, all previous events in the chunk are used.
    split_field : str, default="split"
        Column name used to detect split boundaries when creating chunks.
    """

    event_type: str = "Word"
    sentence_only: bool = False
    max_context_len: int | None = None
    split_field: str = "split"

    def model_post_init(self, log__: tp.Any) -> None:
        super().model_post_init(log__)
        if self.event_type not in ev.Event._CLASSES:
            raise TypeError(f"Event type {self.event_type} not found in events")

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        """In place: adds the concatenation of previous words to the context."""
        words = events.loc[events.type == self.event_type].copy()
        # identify chunks
        previous = words.copy().shift(1)
        timeline_change = words.timeline.astype(str) != previous.timeline.astype(str)
        chunk_change = timeline_change
        if self.split_field in words.columns:
            prev_split = previous[self.split_field].astype(str)
            split_change = words[self.split_field].astype(str) != prev_split
            chunk_change = chunk_change | split_change
        if self.sentence_only:
            sentence_change = words.sequence_id != previous.sequence_id
            chunk_change = chunk_change | sentence_change
        words.loc[words.index, "chunk"] = np.cumsum(chunk_change)
        for _, df in words.groupby("chunk", sort=False):
            if (df.start.diff() < 0).any():  # type: ignore[operator]
                raise ValueError("Events should be ordered by start time")
            texts = df["text"].tolist()
            chunk_contexts: list[str] = []
            # build cumulative context strings, optionally capped to a sliding window
            window: deque[str] | list[str] = (
                deque(maxlen=self.max_context_len)
                if self.max_context_len is not None
                else []
            )
            for word in texts:
                window.append(word)
                chunk_contexts.append(" ".join(window))
            events.loc[df.index, "context"] = chunk_contexts
        return events
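

# Illustrative usage sketch (added for documentation): hypothetical data, and
# ``_demo_add_concatenation_context`` is not part of the library API. With
# ``max_context_len=2`` the context is a sliding window over the last two words.
def _demo_add_concatenation_context() -> None:
    events = pd.DataFrame(
        [
            dict(type="Word", start=0.0, text="a", timeline=0),
            dict(type="Word", start=1.0, text="b", timeline=0),
            dict(type="Word", start=2.0, text="c", timeline=0),
        ]
    )
    out = AddConcatenationContext(max_context_len=2)._run(events)
    # expected contexts: "a", "a b", "b c"
    print(out.context.tolist())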


class AddSummary(EventsTransform):
    """Generate concise summaries for Text events using a pretrained language model.

    This transform processes events of type "Text" and generates a summary for
    each one using a large language model (LLM). Summaries are added as new
    events of type "Summary" in the DataFrame. The summarization prompt asks
    the model to return exactly the requested number of sentences with a
    precise description of the content.

    Parameters
    ----------
    model_name : str, default="meta-llama/Llama-3.2-3B-Instruct"
        Hugging Face model identifier used for text summarization.
    n_sentences_requested : int, default=3
        Number of sentences to include in each summary.
    """

    model_name: str = "meta-llama/Llama-3.2-3B-Instruct"
    n_sentences_requested: int = 3
    _pipeline: tp.Any = pydantic.PrivateAttr()

    @property
    def pipeline(self):
        from transformers import pipeline

        if not hasattr(self, "_pipeline"):
            self._pipeline = pipeline(
                "text-generation",
                model=self.model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
        return self._pipeline

    def summarize(self, text):
        prompt = f"""Summarize the following text in {self.n_sentences_requested} sentences.
Be as precise as possible about what is going on.
Simply return the summary, without any introduction.
"""
        messages = [
            {"role": "system", "content": "You are a professional text summarizer."},
            {"role": "user", "content": prompt + text},
        ]
        outputs = self.pipeline(
            messages,
            max_new_tokens=256,
        )
        return outputs[0]["generated_text"][-1]["content"]

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        texts = events.loc[events.type == "Text"]
        summaries = []
        desc = "Add summary to Text events"
        for text in tqdm(texts.itertuples(), total=len(texts), desc=desc):
            text_str = getattr(text, "text", None)
            if not isinstance(text_str, str):
                continue
            summarized = self.summarize(text_str)
            summary = text._replace(type="Summary", text=summarized)  # type: ignore
            summaries.append(summary)
        events = pd.concat([events, pd.DataFrame(summaries)], ignore_index=True)
        return events
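

# Illustrative usage sketch (added for documentation): hypothetical data, and
# ``_demo_add_summary`` is not part of the library API. Running it downloads
# the model on first use and typically requires a GPU, so it is shown here for
# illustration only.
def _demo_add_summary() -> None:
    events = pd.DataFrame(
        [dict(type="Text", start=0.0, duration=60.0, timeline=0,
              text="Once upon a time, ...")]
    )
    out = AddSummary(n_sentences_requested=2)._run(events)
    print(out.loc[out.type == "Summary", "text"].iloc[0])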


class AddPhonemes(EventsTransform):
    """Add phoneme information to events."""

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        raise NotImplementedError


class AddPartOfSpeech(EventsTransform):
    """Add Part-Of-Speech (POS) tags to events."""

    def _run(self, events: pd.DataFrame) -> pd.DataFrame:
        raise NotImplementedError