# Source code for neuralset.extractors.meta
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import hashlib
import logging
import tempfile
import typing as tp
import exca
import numpy as np
import pandas as pd
import pydantic
import torch
from neuralset import events as _ev
from ..events import Event
from ..events.utils import extract_events
from . import base
from .image import HuggingFaceImage
from .text import HuggingFaceText
logger = logging.getLogger(__name__)
[docs]
class TimeAggregatedExtractor(base.BaseStatic):
    """Collapse the time axis of a wrapped extractor into a time-less output.

    The wrapped ``extractor`` is evaluated on the requested window and its last
    dimension (treated as time) is reduced, either by summing/averaging or by
    picking the first or last sample.

    NOTE: the output depends on the ``start`` and ``duration`` of the window
    (whereas static extractors only depend on the event). Hence ``get_static``
    is deliberately not implemented and raises when called.

    Parameters
    ----------
    time_aggregation: str
        Reduction applied over the last dimension.
        Can be "sum", "mean", "first" or "last".
    n_groups_concat: int | None
        If provided, the last dimension is split into ``n_groups_concat`` chunks
        (through ``torch.tensor_split``), each chunk is reduced on its own, and
        the reduced chunks are concatenated with ``torch.cat`` (default dim 0).
    extractor: BaseExtractor
        The extractor whose output is aggregated.
    """

    time_aggregation: tp.Literal["sum", "mean", "first", "last"] = "mean"
    n_groups_concat: pydantic.PositiveInt | None = None
    event_types: str | tuple[str, ...] = "Event"
    extractor: base.BaseExtractor

    def model_post_init(self, log__: tp.Any) -> None:
        """Mirror the wrapped extractor's event types and enforce a zero frequency.

        Raises
        ------
        ValueError
            If ``frequency`` is not 0.
        """
        self.event_types = self.extractor.event_types
        if self.frequency != 0:
            name = type(self).__name__
            raise ValueError(f"{name}.frequency must be 0")
        return super().model_post_init(log__)

    def prepare(self, events: pd.DataFrame) -> None:
        """Forward preparation to the wrapped extractor."""
        self.extractor.prepare(events)

    def _aggregate(self, out: torch.Tensor) -> torch.Tensor:
        """Reduce the last dimension of ``out`` according to ``time_aggregation``."""
        mode = self.time_aggregation
        if mode == "sum":
            return out.sum(-1)
        if mode == "mean":
            return out.mean(-1)
        if mode == "first":
            return out[..., 0]
        if mode == "last":
            return out[..., -1]
        raise ValueError(f"Unknown time_aggregation: {mode}")

    def __call__(
        self,
        events: tp.Any,  # too complex: pd.DataFrame | list | dict | Event,
        start: float,
        duration: float,
        trigger: Event | pd.Series | dict | None = None,
    ) -> torch.Tensor:
        """Run the wrapped extractor on the window and reduce its last dimension.

        Raises
        ------
        ValueError
            If ``n_groups_concat`` exceeds the size of the last dimension.
        """
        out = self.extractor(events, start, duration, trigger)
        n_groups = self.n_groups_concat
        if n_groups is None:
            # single reduction over the whole window
            return self._aggregate(out)
        if n_groups > out.shape[-1]:
            raise ValueError(
                f"n_groups_concat ({self.n_groups_concat}) cannot be greater than "
                f"the number of time points ({out.shape[-1]})."
            )
        # reduce each chunk of the last dimension separately, then concatenate
        chunks = torch.tensor_split(out, n_groups, dim=-1)
        reduced = [self._aggregate(chunk) for chunk in chunks]
        return torch.cat(reduced)

    def get_static(self, *args: tp.Any, **kwargs: tp.Any) -> torch.Tensor:
        """Always raise: the output depends on the window, not only on the event."""
        msg = f"{type(self).__name__}.get_static should not be called as the extractor is dynamic"
        raise RuntimeError(msg)
[docs]
class AggregatedExtractor(base.BaseExtractor):
    """Aggregate the outputs of multiple extractors.

    Every sub-extractor is called with the same arguments and the resulting
    tensors are combined according to ``extractor_aggregation``:
    "cat" and "stack" call ``torch.cat`` / ``torch.stack`` with ``dim=0``,
    "sum" adds the outputs and "mean" divides that sum by the number of extractors.

    Note that self.extractor_aggregation determines how the extractors are aggregated
    for a given event, whereas self.aggregation determines how different events are
    aggregated (after the extractors have been aggregated).

    Parameters
    ----------
    extractors: list of BaseExtractor
        The extractors to combine. They must be either all static or all dynamic,
        and dynamic extractors must share the same frequency.
    extractor_aggregation: str
        One of "cat", "stack", "mean" or "sum".
    """

    event_types: str | tuple[str, ...] = "Event"
    extractors: list[base.BaseExtractor]
    extractor_aggregation: tp.Literal["cat", "stack", "mean", "sum"] = "cat"
    frequency: tp.Literal["native"] = "native"  # deferred to sub-extractors

    def model_post_init(self, log__: tp.Any) -> None:
        """Check that extractors are all static or all dynamic.

        Also sets ``event_types`` to the union of the event classes handled by
        the sub-extractors.

        Raises
        ------
        ValueError
            If static and dynamic extractors are mixed, or if dynamic extractors
            do not all share the same frequency.
        """
        fts = self.extractors
        static_count = sum(isinstance(f, base.BaseStatic) for f in fts)
        if static_count not in (0, len(fts)):
            raise ValueError("Extractors must be either all static or all dynamic.")
        if not static_count:  # dynamic
            frequencies = {f.frequency for f in fts}
            if len(frequencies) > 1:
                raise ValueError("All extractors must have the same frequency.")
        # Sort the names: iterating a set of class objects yields an order that
        # depends on id-based hashes and can differ from one process to the next,
        # which would make the configured event_types non-reproducible.
        names = {c.__name__ for f in fts for c in f._event_types_helper.classes}
        self.event_types = tuple(sorted(names))
        super().model_post_init(log__)

    def prepare(self, events: pd.DataFrame) -> None:
        """Prepare every sub-extractor on the same events."""
        for extractor in self.extractors:
            extractor.prepare(events)

    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> torch.Tensor:
        """Call every sub-extractor with the same arguments and combine the outputs."""
        out = [feat(*args, **kwargs) for feat in self.extractors]
        aggreg = self.extractor_aggregation
        if aggreg == "sum":
            return sum(out)  # type: ignore
        if aggreg == "mean":
            return sum(out) / len(out)  # type: ignore
        # remaining options ("cat", "stack") are torch functions of the same name
        return getattr(torch, aggreg)(out, dim=0)  # type: ignore
[docs]
class ExtractorPCA(base.BaseStatic):
    """Applies a PCA to another extractor's data

    The underlying extractor is first computed through the prepare method, and then
    the current extractor fits the PCA on it and caches the projected data.
    Both caches are stored.

    Parameters
    ----------
    extractor: Extractor
        the underlying extractor on which the PCA must be applied
    n_components: int
        the number of components of the PCA
    whiten: bool
        whether to whiten the data after the PCA
    """

    extractor: base.BaseExtractor
    n_components: int
    whiten: bool = True
    event_types: str | tuple[str, ...] = "Event"
    infra: exca.MapInfra = exca.MapInfra()
    # identifier of the set of events the PCA was prepared on (empty until prepared)
    _uid: str = pydantic.PrivateAttr("")

    def model_post_init(self, log__: tp.Any) -> None:
        """Copy settings from the wrapped extractor and validate the configuration.

        Raises
        ------
        ValueError
            If a cluster is configured on this extractor's infra.
        NotImplementedError
            If the wrapped extractor is not static or has no ``infra`` attribute.
        """
        self.event_types = self.extractor.event_types
        self.aggregation = self.extractor.aggregation
        self.frequency = self.extractor.frequency  # type: ignore
        super().model_post_init(log__)
        if self.infra.cluster is not None:
            raise ValueError(f"Cannot use a cluster on {self!r}")
        sub = self.extractor
        if not isinstance(sub, base.BaseStatic):
            raise NotImplementedError("Cannot handle non-static extractors for now")
        if not hasattr(sub, "infra"):
            raise NotImplementedError("Cannot handle extractor with no infra")
        # fall back on the wrapped extractor's cache folder when none was provided
        if self.infra.folder is None and sub.infra.folder is not None:  # type: ignore
            self.infra.folder = sub.infra.folder  # type: ignore
            logger.warning(
                "Setting ExtractorPCA infra folder to the underlying extractor's"
            )

    def _prepare_pca_uid(self, obj: tp.Any) -> dict[str, _ev.Event]:
        """Set ``_uid`` for the given events and return the events left to process.

        Returns an empty dict when every event is already in the cache, and the
        full mapping (PCA uid -> event) otherwise.

        Raises
        ------
        RuntimeError
            If the cache holds only part of the events.
        """
        item_uid = self.extractor.infra.item_uid  # type: ignore
        # one event per underlying-extractor uid
        unique_events = {
            item_uid(e): e for e in self._event_types_helper.extract(obj)
        }
        # compute a unique hash for this set of data
        digest = hashlib.sha256()
        for uid in sorted(unique_events):
            digest.update(uid.encode("utf8"))
        self._uid = f"{digest.hexdigest()[:8]},{len(unique_events)}"
        logger.debug("%r.prepare called for uid=%r", self, self._uid)
        # use this to have a specific folder for caching
        cache = self.infra.cache_dict
        pca_events = {self._to_uid(e): e for e in unique_events.values()}
        if not cache:
            return pca_events
        # if cache is already filled, we should be done
        missing = set(pca_events) - set(cache)
        if not missing:
            return {}
        if len(missing) != len(pca_events):
            msg = "Cache exists but with missing items, something went wrong"
            msg += f"\n(eg: missing {list(missing)[0]} in {cache.folder})"
            raise RuntimeError(msg)
        return pca_events

    def prepare(self, obj: tp.Any) -> None:
        """Fit the PCA on the wrapped extractor's data and cache the projections.

        Raises
        ------
        RuntimeError
            If the wrapped extractor yields arrays with a non-zero frequency, or
            if the number of projected rows does not match the number of events.
        """
        pca_events = self._prepare_pca_uid(obj)
        if not pca_events:
            logger.debug("In %r, all events for uid=%r are cached", self, self._uid)
            return  # all done
        from sklearn.decomposition import PCA

        # Note: all the PCA has to be done on prepare (in the main thread)
        # because it cannot be split and distributed onto many machines
        events = list(pca_events.values())
        self.extractor.prepare(events)
        # window covering all the events
        start = min(e.start for e in events)
        stop = max(e.start + e.duration for e in events)
        duration = stop - start
        timed_arrays = list(
            self.extractor._get_timed_arrays(events, start=start, duration=duration)
        )
        if any(ta.frequency for ta in timed_arrays):
            raise RuntimeError("Dynamic extractors are not currently supported")
        # one flattened row per event
        flat = [ta.data.ravel() for ta in timed_arrays]
        pca = PCA(n_components=self.n_components, whiten=self.whiten)
        pca_data = pca.fit_transform(np.array(flat))
        if len(events) != len(pca_data):
            raise RuntimeError("Something went wrong")
        # write to cache
        written = set()
        with self.infra.cache_dict.writer() as w:
            for i_uid, d in zip(pca_events, pca_data):
                if i_uid in written:
                    continue
                w[i_uid] = d
                written.add(i_uid)
        logger.debug("%r.prepare finished with uid=%r", self, self._uid)

    def _to_uid(self, event: _ev.Event) -> str:
        """Build the cache uid of an event from the prepare uid and the event's uid.

        Raises
        ------
        RuntimeError
            If ``prepare`` has not set ``_uid`` yet.
        """
        if not self._uid:
            msg = f"prepare method must be called first for extractor {self!r}"
            raise RuntimeError(msg)
        event_uid = self.extractor.infra.item_uid(event)  # type: ignore
        uid = f"{self._uid},{event_uid}"
        # apply item_uid because of automatic uid shortening
        return self.infra.item_uid(uid)

    @infra.apply(item_uid=str, exclude_from_cache_uid="method:_exclude_from_cache_uid")
    def _get_data(self, uids: tp.Iterable[str]) -> np.ndarray:
        """Always raise when executed: the data is expected to come from the cache."""
        uids = list(uids)
        raise RuntimeError(
            "Events should have been prepared first, trying to recover "
            f"{uids[0]} (prepare uid: {self._uid!r})"
        )

    def _get_timed_arrays(
        self, events: list[_ev.Event], start: float, duration: float
    ) -> tp.Iterable[base.TimedArray]:
        """Yield one zero-frequency TimedArray per event with its PCA data."""
        uids = list(map(self._to_uid, events))
        for event, d in zip(events, self._get_data(uids)):
            yield base.TimedArray(
                data=d, frequency=0.0, start=event.start, duration=event.duration
            )
[docs]
class HuggingFacePCA(ExtractorPCA):
    """Applies a PCA to the underlying HuggingFace extractor.

    The underlying extractor is first computed through the prepare method, and then
    the current extractor applies the PCA on it.
    Compared to the ExtractorPCA extractor, HuggingFacePCA handles caching of multiple
    layers at once in the cache: one PCA is fitted per layer.
    By default, the hugging face extractor cache is deleted afterwards.

    Parameters
    ----------
    extractor: HuggingFace Extractor
        the underlying extractor on which the PCA must be applied
        (must be a HuggingFaceText or HuggingFaceImage instance)
    n_components: int
        the number of components of the PCA
    whiten: bool
        whether to whiten the data after the PCA
    use_tmp_cache: bool
        whether to use a temporary cache folder for the underlying
        extractor that gets deleted afterwards
    """

    extractor: base.BaseExtractor
    # uses a temporary cache for the extractor which gets deleted afterwards
    use_tmp_cache: bool = True

    def model_post_init(self, log__: tp.Any) -> None:
        """Validate the wrapped extractor, then run the parent initialization.

        Raises
        ------
        TypeError
            If the extractor is neither a HuggingFaceText nor a HuggingFaceImage.
        RuntimeError
            If the extractor's infra folder is not set.
        """
        if not isinstance(self.extractor, (HuggingFaceText, HuggingFaceImage)):
            raise TypeError("Only HuggingFaceText and HuggingFaceImage are supported")
        if self.extractor.infra.folder is None:
            raise RuntimeError("extractor's infra folder should be provided")
        super().model_post_init(log__)

    @classmethod
    def _exclude_from_cls_uid(cls) -> list[str]:
        """Return the parent's exclusion list extended with ``use_tmp_cache``."""
        return super()._exclude_from_cls_uid() + ["use_tmp_cache"]

    def _exclude_from_cache_uid(self) -> list[str]:
        """Return the parent's exclusion list plus the wrapped extractor's own.

        The wrapped extractor's exclusions are prefixed with ``extractor.``.
        """
        excl = [f"extractor.{x}" for x in self.extractor._exclude_from_cache_uid()]
        # model_post_init (parent) copies the extractor's frequency onto this
        # object, so exclude it here too whenever the extractor excludes it
        if "extractor.frequency" in excl:
            excl.append("frequency")
        return super()._exclude_from_cache_uid() + excl

    def prepare(self, obj: tp.Any) -> None:
        """Compute the wrapped extractor's embeddings, fit one PCA per layer and
        write the projected data to this extractor's cache.

        Raises
        ------
        TypeError
            If the extractor is not a ``base.HuggingFaceMixin``.
        RuntimeError
            If the number of projected items does not match the number of events.
        """
        pca_events = self._prepare_pca_uid(obj)
        if not pca_events:
            logger.debug("In %r, all events for uid=%r are cached", self, self._uid)
            return  # all done
        from sklearn.decomposition import PCA

        # Note: all the PCA has to be done on prepare (in the main thread)
        # because it cannot be split and distributed onto many machines
        if not isinstance(self.extractor, base.HuggingFaceMixin):  # for typing
            raise TypeError("Only HuggingFaceText and HuggingFaceImage are supported")
        events = list(pca_events.values())
        basefolder = self.extractor.infra.folder
        # the temporary directory is always created (under the extractor's folder),
        # but only used as cache folder when use_tmp_cache is set
        with tempfile.TemporaryDirectory(prefix="HF-PCA-tmp", dir=basefolder) as tmp:
            # change to a temporary folder that will be deleted
            feat_folder = tmp if self.use_tmp_cache else basefolder
            feat = self.extractor.infra.clone_obj(**{"infra.folder": feat_folder})
            feat.prepare(events)
            embds = list(feat._get_data(events))  # N x Layer x *Embd
            # note: this may be too big to keep in memory, so we don't turn it to array
            # (keep as list of memmaps)
            # compute PCA per layer
            layers = []
            for k in range(embds[0].shape[0]):
                layer = np.array([e[k].ravel() for e in embds])  # N * prod(*Embd)
                pca = PCA(n_components=self.n_components, whiten=self.whiten)
                layers.append(pca.fit_transform(layer))
                del layer  # free memory if need be
            embds.clear()
            # avoid keeping files open, explicitly delete cache
            # (done before leaving the with-block, i.e. before the folder is removed)
            del feat.infra._state.cache_dict
            del feat
        # back to N * Layer * n_components
        pca_embds = np.array(layers).transpose((1, 0, 2))
        if len(events) != len(pca_embds):
            raise RuntimeError("Something went wrong")
        # write to cache
        done = set()
        with self.infra.cache_dict.writer() as w:
            for i_uid, d in zip(pca_events, pca_embds):
                if i_uid not in done:
                    w[i_uid] = d
                    done.add(i_uid)
        logger.debug("%r.prepare finished with uid=%r", self, self._uid)

    def _get_timed_arrays(
        self, events: list[_ev.Event], start: float, duration: float
    ) -> tp.Iterable[base.TimedArray]:
        """Yield one zero-frequency TimedArray per event from the cached PCA data.

        Raises
        ------
        TypeError
            If the extractor is not a ``base.HuggingFaceMixin``.
        """
        uids = [self._to_uid(e) for e in events]
        feat = self.extractor
        if not isinstance(feat, base.HuggingFaceMixin):  # for typing
            raise TypeError("Only HuggingFaceText and HuggingFaceImage are supported")
        for event, d in zip(events, self._get_data(uids)):
            # cached data holds several layers: let the extractor aggregate them
            if feat.cache_n_layers is not None:
                d = feat._aggregate_layers(d)
            yield base.TimedArray(
                data=d, frequency=0, start=event.start, duration=event.duration
            )
[docs]
class CroppedExtractor(base.BaseStatic):  # can be static or not
    """Crop an extractor to a given offset and duration.

    The wrapped extractor is called on a sub-window of the requested window:
    it starts ``offset`` seconds after the requested ``start`` and lasts
    ``duration`` seconds (or until the end of the requested window).

    Parameters
    ----------
    extractor: BaseExtractor
        The extractor to crop.
    offset: float
        The offset (in seconds) added to the ``start`` passed to ``__call__``
        to begin the crop.
    duration: PositiveFloat | None
        The duration (in seconds) of the crop. If None, the crop extends to the
        end of the requested window.
    frequency: Literal["native"]
        Declared as "native"; it is replaced by the wrapped extractor's frequency
        at initialization.
    """

    event_types: str | tuple[str, ...] = "Event"
    extractor: base.BaseExtractor
    offset: float = 0
    duration: pydantic.PositiveFloat | None = None
    frequency: tp.Literal["native"] = "native"  # type: ignore

    def model_post_init(self, log__: tp.Any) -> None:
        """Mirror the wrapped extractor's event types and frequency."""
        self.event_types = self.extractor.event_types
        self.frequency = self.extractor.frequency  # type: ignore
        super().model_post_init(log__)

    def __call__(
        self,
        events: tp.Any,
        start: float,
        duration: float,
        trigger: Event | pd.Series | dict | None = None,
    ) -> torch.Tensor:
        """Call the wrapped extractor on the cropped window.

        Raises
        ------
        ValueError
            If the crop duration does not fit in ``duration - offset``, or if the
            offset is not smaller than the requested duration.
        """
        if self.duration is not None and (self.duration > duration - self.offset):
            msg = f"Crop duration ({self.duration}) cannot be greater than "
            # f-prefix is required here so the actual values get reported
            msg += f"segment duration minus offset ({duration} - {self.offset})"
            raise ValueError(msg)
        elif self.offset >= duration:
            msg = f"Crop offset ({self.offset}) must be less than event duration ({duration})"
            raise ValueError(msg)
        # without an explicit crop duration, go up to the end of the window
        duration_ = duration - self.offset if self.duration is None else self.duration
        return self.extractor(events, start + self.offset, duration_, trigger)

    def prepare(self, obj: tp.Any) -> None:
        """Forward preparation to the wrapped extractor."""
        self.extractor.prepare(obj)
[docs]
class ToStatic(base.BaseStatic):
    """
    Turn a dynamic extractor into a time-less one by sampling it at the trigger.

    The wrapped extractor is called with a zero-duration window starting at the
    trigger's start, and the single resulting time sample is returned without
    its time dimension.

    Parameters
    ----------
    extractor: BaseExtractor
        The (non-static) extractor to sample. Its aggregation must be "single".
    """

    extractor: base.BaseExtractor
    event_types: str | tuple[str, ...] = "Event"
    frequency: pydantic.PositiveFloat = 0.0
    aggregation: tp.Literal["trigger"] = "trigger"

    def model_post_init(self, context: tp.Any) -> None:
        """Validate the wrapped extractor and mirror its event types.

        Raises
        ------
        ValueError
            If the wrapped extractor is already static.
        NotImplementedError
            If the wrapped extractor's aggregation is not "single".
        """
        wrapped_is_static = isinstance(self.extractor, base.BaseStatic)
        if wrapped_is_static:
            raise ValueError("ToStatic cannot crop a static extractor as it is timeless.")
        if self.extractor.aggregation != "single":
            raise NotImplementedError(
                f"ToStatic only accept extractor with `single` aggregation, got {self.extractor.aggregation}"
            )
        self.event_types = self.extractor.event_types
        super().model_post_init(context)

    def __call__(
        self,
        events: tp.Any,
        start: float,
        duration: float,
        trigger: Event | pd.Series | dict,
    ) -> torch.Tensor:
        """Return the wrapped extractor's output at the trigger's start.

        Raises
        ------
        RuntimeError
            If ``trigger`` does not resolve to exactly one event, if its start
            lies outside ``[start, start + duration]``, or if the wrapped
            extractor does not return exactly one time sample.
        """
        if not isinstance(trigger, Event):
            # normalize the trigger (e.g. Series or dict) into a single Event
            triggers = extract_events(trigger)
            if len(triggers) != 1:
                msg = f"trigger must be a single event, got {len(triggers)} from {trigger!r}"
                raise RuntimeError(msg)
            trigger = triggers[0]
        before_window = trigger.start < start
        after_window = trigger.start > start + duration
        if before_window or after_window:
            msg = (
                f"trigger {trigger.start} outside segment [{start}, {start + duration}]."
            )
            raise RuntimeError(msg)
        # zero-duration window at the trigger: expected to give one time sample
        out = self.extractor(events, trigger.start, 0, None)
        n_samples = out.shape[-1]
        if n_samples != 1:
            msg = "something went wrong, the extractor output should have 1 time sample"
            raise RuntimeError(msg)
        return out[..., 0]

    def prepare(self, obj: tp.Any) -> None:
        """Forward preparation to the wrapped extractor."""
        self.extractor.prepare(obj)