Source code for neuralset.extractors.video

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import inspect
import logging
import typing as tp
import warnings

import numpy as np
import pydantic
import torch
from exca import MapInfra
from tqdm import tqdm

from neuralset import base as nsbase
from neuralset.events import etypes as evts

from . import base as extractor_base
from . import hf
from . import image as image_extractors

logger = logging.getLogger(__name__)
# activate with:
# logging.getLogger("neuralset").setLevel(logging.DEBUG)

_VideoImage = image_extractors._VideoImage


def resamp_first_dim(data: torch.Tensor, new_first_dim: int) -> torch.Tensor:
    if data.shape[0] == new_first_dim:
        return data
    import julius

    logger.debug(
        "Resampling video embedding from %s samples to %s", data.shape[0], new_first_dim
    )
    resample = julius.resample.ResampleFrac(
        old_sr=data.shape[0],
        new_sr=new_first_dim,
    ).to(data.device)
    dims = []
    for dim in tqdm(data.reshape(data.shape[0], -1).T):
        dims.append(resample(dim.float()))
    # TODO: stack an extra frame here?
    output = torch.stack(dims).reshape(-1, *data.shape[1:])
    return output


class HuggingFaceVideoConfig(hf.HuggingFaceConfig):
    processor_kwargs: dict[str, tp.Any] | None = {"do_rescale": True}
    HF_CLASS_DEFAULTS: tp.ClassVar[dict[str, dict[str, str]]] = {
        "vjepa2": {"processor_cls_name": "AutoVideoProcessor"},
        "google/vivit": {
            "model_cls_name": "VivitModel",
            "processor_cls_name": "VivitImageProcessor",
        },
    }


[docs] class HuggingFaceVideo(extractor_base.BaseExtractor, hf.HuggingFaceMixin): """Extract video embeddings using a native HuggingFace video model. Videos are divided into clips of `clip_duration` seconds at the specified frequency. Each clip is processed by the video model, and features are aggregated over layers/tokens using the HuggingFace extractor options. Parameters ---------- model_name : str, default="MCG-NJU/videomae-base" HuggingFace video model identifier. Image models are not accepted here; use `HuggingFaceImage` for frame-by-frame video embeddings. clip_duration : float | None, default=None Duration (in seconds) of video sub-clips to process. If None, defaults to one timestep (1 / frequency). max_imsize : int | None, default=None Maximum image dimension for downsampling before processing. num_frames : int Number of frames to pass to the video model per clip. """ event_types: tp.Literal["Video"] = "Video" SUPPORTED_MODELS: tp.ClassVar[tuple[str, ...]] = ( "vjepa2", "videomae", "google/vivit", "facebook/timesformer", ) # class attributes requirements: tp.ClassVar[tuple[str, ...]] = ( "torchvision>=0.15.2", "julius>=0.2.7", ) model_name: str = "MCG-NJU/videomae-base" hf_config: HuggingFaceVideoConfig = HuggingFaceVideoConfig() clip_duration: float | None = None max_imsize: int | None = None num_frames: int infra: MapInfra = MapInfra( timeout_min=120, gpus_per_node=1, cpus_per_task=8, min_samples_per_job=128, version="v5", ) @pydantic.model_validator(mode="before") @classmethod def _reject_previous_api(cls, data: tp.Any) -> tp.Any: if isinstance(data, dict) and "image" in data: msg = ( "HuggingFaceVideo no longer accepts the previous API " "`image=HuggingFaceImage(...)`. For frame-by-frame video " "embeddings, instantiate HuggingFaceImage with event_types='Video'. " "For native video models, pass the model name directly as " "HuggingFaceVideo(model_name=..., num_frames=...)." ) raise ValueError(msg) return data @pydantic.field_validator("model_name") @classmethod def _validate_model_name(cls, model_name: str) -> str: if any(z in model_name for z in cls.SUPPORTED_MODELS): return model_name msg = ( "The HuggingFaceVideo API now only supports native video models. " "For the previous frame-by-frame API, instantiate HuggingFaceImage " f"with event_types='Video' instead of using model_name={model_name!r}." ) raise ValueError(msg) @classmethod def _exclude_from_cls_uid(cls) -> list[str]: return hf.HuggingFaceMixin._exclude_from_cls_uid() def _exclude_from_cache_uid(self) -> list[str]: return extractor_base.BaseExtractor._exclude_from_cache_uid( self ) + hf.HuggingFaceMixin._exclude_from_cache_uid(self) def _get_timed_arrays( self, events: list[evts.Video], start: float, duration: float ) -> tp.Iterable[nsbase.TimedArray]: for event, ta in zip(events, self._get_data(events)): sub = ta.with_start(event.start).overlap(start=start, duration=duration) if self.cache_n_layers is not None: sub.data = self._aggregate_layers(sub.data) yield sub @infra.apply( item_uid=lambda e: e._splittable_event_uid(), exclude_from_cache_uid="method:_exclude_from_cache_uid", ) def _get_data(self, events: list[evts.Video]) -> tp.Iterator[nsbase.TimedArray]: # read all videos of the events logging.getLogger("neuralset").setLevel(logging.DEBUG) self._warn_if_config_num_frames_mismatch() freq = events[0].frequency if self.frequency == "native" else self.frequency T = 1 / freq if self.clip_duration is None else self.clip_duration subtimes = [k / self.num_frames * T for k in reversed(range(self.num_frames))] for event in events: video = event.read() freq = self.frequency if self.frequency != "native" else event.frequency expect_frames = nsbase.Frequency(freq).to_ind(event.duration) logger.debug( "Loaded Video (duration %ss at %sfps, shape %s):\n%s", video.duration, video.fps, tuple(video.size), event.filepath, ) # time at end of sample: times = np.linspace(0, video.duration, expect_frames + 1)[1:] # samples the frames in-between the main frequency output = np.array([]) # pylint: disable=protected-access for k, t in tqdm(enumerate(times), total=len(times), desc="Encoding video"): ims = [_VideoImage(video=video, time=max(0, t - t2)) for t2 in subtimes] pil_imgs = [i.read() for i in ims] # resize if images are too big if pil_imgs and self.max_imsize is not None: factor = max(pil_imgs[0].size) / self.max_imsize if factor > 1: size = tuple(int(s / factor) for s in pil_imgs[0].size) pil_imgs = [pi.resize(size) for pi in pil_imgs] data = np.array([np.array(pi) for pi in pil_imgs]) t_embd = self._predict_hidden_states(data) if t_embd.shape[0] != 1: raise RuntimeError(f"Found several batches: {t_embd.shape}") t_embd = t_embd[0] # aggregate_tokens works on non-batched-data embd = self._aggregate_tokens(t_embd).cpu().numpy() if self.cache_n_layers is None: embd = self._aggregate_layers(embd) if not output.size: output = np.zeros((len(times),) + embd.shape) logger.debug("Created Tensor with size %s", output.shape) output[k] = embd video.close() # set first (time) dim to last output = output.transpose(list(range(1, output.ndim)) + [0]) yield nsbase.TimedArray( data=output.astype(np.float32), frequency=freq, start=nsbase._UNSET_START, duration=event.duration, ) def _predict_hidden_states(self, images: np.ndarray) -> torch.Tensor: kwargs: dict[str, tp.Any] = { self._processor_input_field(): list(images), "return_tensors": "pt", } inputs = self.processor(**kwargs) # prevent nans (happening for uniform images) image_extractors._fix_pixel_values(inputs) inputs = inputs.to(self.model_device) with torch.inference_mode(): pred = self.model(**inputs, output_hidden_states=True) states = pred.hidden_states out = torch.cat([x.unsqueeze(1) for x in states], axis=1) # type: ignore return out # B x L x ... def _processor_input_field(self) -> tp.Literal["images", "videos"]: parameters = inspect.signature(self.processor.__call__).parameters return "videos" if "videos" in parameters else "images" def _config_num_frames(self) -> int | None: config = getattr(self.model, "config", None) if config is None: return None config = getattr(config, "vision_config", config) num_frames = getattr(config, "num_frames", None) return num_frames if isinstance(num_frames, int) else None def _warn_if_config_num_frames_mismatch(self) -> None: config_num_frames = self._config_num_frames() if config_num_frames is None or config_num_frames == self.num_frames: return warnings.warn( f"Model {self.model_name!r} config expects {config_num_frames} frames, " f"but HuggingFaceVideo.num_frames={self.num_frames}.", stacklevel=2, )