# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import typing as tp
import numpy as np
import torch
from exca import MapInfra
from tqdm import tqdm
from neuralset import base as nsbase
from neuralset.events import etypes as evts
from neuralset.utils import ignore_all
from .base import BaseExtractor
from .image import HuggingFaceImage, _fix_pixel_values
logger = logging.getLogger(__name__)
# activate with:
# logging.getLogger("neuralset").setLevel(logging.DEBUG)
class _VideoImage(evts.Image):
    """Image event wrapper for extracting individual frames from a video.

    This class extends the base Image event to enable frame-by-frame processing
    of videos. It extracts a single frame at a specified time point and presents
    it as an Image event.

    Parameters
    ----------
    start : float, default=0.0
        The start time of the image event.
    timeline : str, default="fake"
        Timeline identifier for the image event. Uses "fake" since this is a virtual
        event derived from video data rather than an original image event.
    video : moviepy.editor.VideoFileClip
        The loaded video object from which to extract frames. Should be a MoviePy
        VideoFileClip instance with frame extraction capabilities.
    time : float, default=0.0
        The exact timestamp (in seconds) within the video at which to extract
        the frame. Must be within [0, video.duration].
    duration : float, default=1.0
        The nominal duration of the image event (in seconds). By convention, it is set to 1.
    filepath : str, default=""
        Auto-generated filepath identifier for caching purposes. Automatically
        constructed as "{video_filename}:{time:.3f}" to ensure unique cache keys
        for each frame.
    """

    start: float = 0.0
    timeline: str = "fake"
    duration: float = 1.0
    video: tp.Any
    time: float = 0.0
    filepath: str = ""

    def model_post_init(self, log__: tp.Any) -> None:
        # filepath must not be user-provided: it is derived from the video
        # filename and the frame timestamp so each frame gets a unique cache key
        if self.filepath:
            raise ValueError("Filepath is automatically filled")
        # create a custom filepath for caching
        self.filepath = f"{self.video.filename}:{self.time:.3f}"
        super().model_post_init(log__)

    def _read(self) -> tp.Any:
        """Extract the frame at ``self.time`` and return it as a PIL image."""
        # import the submodule explicitly: `import PIL` alone does not
        # guarantee that the `PIL.Image` submodule is loaded
        import PIL.Image  # noqa

        # may require: pip install moviepy==2.0.0.dev2
        with ignore_all():
            img = self.video.get_frame(self.time)
        return PIL.Image.fromarray(img.astype("uint8"))
def resamp_first_dim(data: torch.Tensor, new_first_dim: int) -> torch.Tensor:
    """Resample the first (time) dimension of ``data`` to ``new_first_dim`` samples.

    Uses julius band-limited fractional resampling along the first axis while
    preserving all other dimensions. Returns ``data`` unchanged (the same
    object) when the first dimension already has the requested size.

    Parameters
    ----------
    data : torch.Tensor
        Tensor whose first dimension is time.
    new_first_dim : int
        Desired number of samples along the first dimension.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(new_first_dim, *data.shape[1:])`` (float dtype when
        resampling occurred).
    """
    if data.shape[0] == new_first_dim:
        return data
    import julius

    logger.debug(
        "Resampling video embedding from %s samples to %s", data.shape[0], new_first_dim
    )
    resample = julius.resample.ResampleFrac(
        old_sr=data.shape[0],
        new_sr=new_first_dim,
    ).to(data.device)
    # julius resamples the last dimension, so move time last: (features, time)
    flat = data.reshape(data.shape[0], -1).T.float()
    resampled = resample(flat)  # -> (features, new_first_dim)
    # transpose back to time-major BEFORE reshaping: reshaping the
    # (features, new_time) tensor directly would scramble the layout, since the
    # flattened output must be ordered time-major (t * n_features + d)
    output = resampled.T.reshape(-1, *data.shape[1:])
    return output
class HuggingFaceVideo(BaseExtractor):
    """Extract video features using a HuggingFace transformer model.

    This feature extractor supports two processing modes:

    1. **Image-based processing**: When using an image model, videos are sampled
       at the specified frequency and each frame is processed independently.
    2. **Video-based processing**: When using a native video model (e.g., VideoMAE,
       XClip), videos are divided into clips of `clip_duration` seconds at the
       specified frequency. Each clip is processed by the video model, and features
       are aggregated over time.

    Parameters
    ----------
    image : HuggingFaceImage, default=HuggingFaceImage(model_name="MCG-NJU/videomae-base")
        Image or video feature extractor configuration. If `image.model_name` refers
        to an image model (e.g., ViT), frames are extracted and processed independently.
        If it's a video model, clips are processed using the native video architecture.
    use_audio : bool, default=True
        Whether to include audio alongside video frames during feature extraction.
        Only applicable for models that support multimodal inputs (e.g., LLaVA-Video).
    clip_duration : float | None, default=None
        Duration (in seconds) of video sub-clips to process. If None, defaults to
        one timestep (1 / frequency).
    max_imsize : int | None, default=None
        Maximum image dimension for downsampling before processing. Useful for
        memory-constrained scenarios. For example, Phi-4 downsizes to 448×448
        before tokenization.
    layer_type : str, default=""
        Specific layer extraction mode for certain models.
        For XClip: Use "mit" to extract from Multi-frame Integration Transformer
        layers instead of vision backbone layers.
        For LLaVA models: Must be a prompt string containing the ``<video>`` token
        (e.g., ``"<|user|><video><|end|><|assistant|>"``).

        .. note:: The pipe characters in the example are literal LLaVA tokens.
    num_frames : int | None, default=None
        Number of frames to pass to the video model per clip. If None, uses the
        model's default frame count (e.g., 16 for VideoMAE, 8 for XClip, 64 for VJepa2).
    """

    event_types: tp.Literal["Video"] = "Video"
    # class attributes
    requirements: tp.ClassVar[tuple[str, ...]] = (
        "torchvision>=0.15.2",
        "julius>=0.2.7",
    )
    image: HuggingFaceImage = HuggingFaceImage(
        model_name="MCG-NJU/videomae-base",
        infra=MapInfra(keep_in_ram=False),
        imsize=None,  # type: ignore[arg-type]
    )
    use_audio: bool = True
    clip_duration: float | None = None
    max_imsize: int | None = None
    layer_type: str = ""
    num_frames: int | None = None
    infra: MapInfra = MapInfra(
        timeout_min=120,
        gpus_per_node=1,
        cpus_per_task=8,
        min_samples_per_job=128,
        version="v5",
    )

    def model_post_init(self, log__: tp.Any) -> None:
        """Validate the sub-extractor configuration after pydantic initialization."""
        super().model_post_init(log__)
        if self.image.infra.keep_in_ram:
            msg = "video.image.infra.keep_in_ram must be False to avoid overload"
            raise ValueError(msg)
        # the image sub-infra must not define its own cache folder/cluster:
        # caching and scheduling are handled by this extractor's infra
        for name in ["folder", "cluster"]:
            val = getattr(self.image.infra, name)
            if val is not None:
                raise ValueError(f"image.infra.{name} must be None, (got {val!r})")
        model = self.image.model_name
        if "video" in model and "videomae" not in model:
            msg = "Currently unclear if this supports any video model but videomae model"
            raise NotImplementedError(msg)
        # fail early on invalid layer_type instead of at extraction time
        _HFVideoModel.check_layer_type(layer_type=self.layer_type, model_name=model)

    def _exclude_from_cache_uid(self) -> list[str]:
        """Exclude this extractor's and the nested image extractor's non-semantic fields."""
        ex = super()._exclude_from_cache_uid()
        im_ex = self.image._exclude_from_cache_uid()
        return ex + [f"image.{n}" for n in im_ex]

    def _get_data_from_image_model(
        self, events: list[evts.Video]
    ) -> tp.Iterator[nsbase.TimedArray]:
        """Encode videos frame by frame through the (image) sub-extractor.

        Frames are sampled at the requested frequency, embedded independently,
        resampled in time if needed, and yielded as one TimedArray per event
        with time as the last dimension.
        """
        # safety check: a genuine video model exposes num_frames in its config
        # and must not be driven through the frame-by-frame path
        config = getattr(self.image.model.model, "config", object())
        config = getattr(config, "vision_config", config)  # xclip
        if hasattr(config, "num_frames"):
            name = self.image.model_name
            msg = f"Model {name!r} seems to be a video model, but treated as image"
            raise RuntimeError(msg)
        for event in tqdm(events, desc="Computing video latents"):
            video = event.read()
            n_frames = int(video.duration * video.fps)
            freq = event.frequency if self.frequency == "native" else self.frequency
            expect_frames = nsbase.Frequency(freq).to_ind(event.duration)
            logger.debug(
                "Loaded Video (duration %ss at %sfps, %s frames of shape %s):\n %s",
                video.duration,
                video.fps,
                n_frames,
                tuple(video.size),
                event.filepath,
            )
            times = np.linspace(0, video.duration, expect_frames)
            # TODO warn about aspect ratio? resize leads to aspect ratio 1:1
            ims = [_VideoImage(video=video, time=t) for t in times]
            output = torch.Tensor([])
            # pylint: disable=protected-access
            k = -1
            for k, embd in enumerate(
                tqdm(self.image._get_data(ims), total=len(times), leave=False)
            ):
                if not k:
                    # allocate the full buffer once the embedding shape is known
                    output = torch.zeros(len(times), *embd.shape)
                    logger.debug("Created Tensor with size %s", output.shape)
                output[k] = torch.Tensor(embd)
            logger.debug("Finished encoding video at video frame rate")
            if k != len(times) - 1:
                raise RuntimeError(f"Expected {len(times)} frames, got {k + 1}")
            # resample full output
            if abs(output.shape[0] - expect_frames) > 1:  # some flexibility allowed
                output = output.to(self.image.device)
                output = resamp_first_dim(output, expect_frames).cpu()
                logger.debug("Resampled video embeddings at frequency %s", self.frequency)
            # set first (time) dim to last
            output = output.permute(list(range(1, output.dim())) + [0])
            yield nsbase.TimedArray(
                data=output.cpu().numpy().astype(np.float32),
                frequency=freq,
                start=nsbase._UNSET_START,
                duration=event.duration,
            )

    def _get_timed_arrays(
        self, events: list[evts.Video], start: float, duration: float
    ) -> tp.Iterable[nsbase.TimedArray]:
        """Yield per-event arrays restricted to the [start, start + duration] window."""
        for event, ta in zip(events, self._get_data(events)):
            sub = ta.with_start(event.start).overlap(start=start, duration=duration)
            # layers were kept in the cache: aggregate them only now, at read time
            if self.image.cache_n_layers is not None:
                sub.data = self.image._aggregate_layers(sub.data)
            yield sub

    @infra.apply(
        item_uid=lambda e: e._splittable_event_uid(),
        exclude_from_cache_uid="method:_exclude_from_cache_uid",
    )
    def _get_data(self, events: list[evts.Video]) -> tp.Iterator[nsbase.TimedArray]:
        """Compute (and cache, through infra) one TimedArray per video event."""
        # models not listed in _HFVideoModel.MODELS are treated as image models
        if not any(z in self.image.model_name for z in _HFVideoModel.MODELS):
            yield from self._get_data_from_image_model(events)
            return
        model = _HFVideoModel(
            model_name=self.image.model_name,
            pretrained=self.image.pretrained,
            layer_type=self.layer_type,
            num_frames=self.num_frames,
        )
        if model.model.device.type == "cpu":
            # may already be dispatched (with "accelerate")
            model.model.to(self.image.device)
        # videomae = 16 frames
        # xclip = 8 or 16 frames (unclear)
        freq = events[0].frequency if self.frequency == "native" else self.frequency
        # clip length defaults to one timestep of the output frequency
        T = 1 / freq if self.clip_duration is None else self.clip_duration
        # relative offsets (from clip end, going backward) of the frames fed to
        # the model for each clip
        subtimes = list(
            k / model.num_frames * T for k in reversed(range(model.num_frames))
        )  # type: ignore
        for event in events:
            video = event.read()
            audio = video.audio if self.use_audio else None
            freq = self.frequency if self.frequency != "native" else event.frequency
            expect_frames = nsbase.Frequency(freq).to_ind(event.duration)
            logger.debug(
                "Loaded Video (duration %ss at %sfps, shape %s):\n%s",
                video.duration,
                video.fps,
                tuple(video.size),
                event.filepath,
            )
            # time at end of sample:
            times = np.linspace(0, video.duration, expect_frames + 1)[1:]
            # samples the frames in-between the main frequency
            output = np.array([])
            # pylint: disable=protected-access
            for k, t in tqdm(enumerate(times), total=len(times), desc="Encoding video"):
                ims = [_VideoImage(video=video, time=max(0, t - t2)) for t2 in subtimes]
                audio_clip = (
                    audio.subclipped(max(0, t - T), t) if audio is not None else None
                )
                pil_imgs = [i.read() for i in ims]
                # resize if images are too big
                if pil_imgs and self.max_imsize is not None:
                    factor = max(pil_imgs[0].size) / self.max_imsize
                    if factor > 1:
                        size = tuple(int(s / factor) for s in pil_imgs[0].size)
                        pil_imgs = [pi.resize(size) for pi in pil_imgs]
                data = np.array([np.array(pi) for pi in pil_imgs])
                t_embd = model.predict_hidden_states(data, audio_clip)
                if t_embd.shape[0] != 1:
                    raise RuntimeError(f"Found several batches: {t_embd.shape}")
                t_embd = t_embd[0]  # aggregate_tokens works on non-batched-data
                embd = self.image._aggregate_tokens(t_embd).cpu().numpy()
                if self.image.cache_n_layers is None:
                    embd = self.image._aggregate_layers(embd)
                if not output.size:
                    # allocate the full buffer once the embedding shape is known
                    output = np.zeros((len(times),) + embd.shape)
                    logger.debug("Created Tensor with size %s", output.shape)
                output[k] = embd
            video.close()
            # set first (time) dim to last
            output = output.transpose(list(range(1, output.ndim)) + [0])
            yield nsbase.TimedArray(
                data=output.astype(np.float32),
                frequency=freq,
                start=nsbase._UNSET_START,
                duration=event.duration,
            )
class _HFVideoModel:
"""Wrapper that provides a unified interface for loading and using various HuggingFace
video models
Parameters
----------
model_name : str
HuggingFace model identifier.
The model will be loaded from the HuggingFace Hub. Please note that you may have to install additional dependencies to load it correctly.
pretrained : bool, default=True
Whether to load pretrained weights. If False, initializes the model with
random weights from the model configuration.
layer_type: str, default=""
Specific layer extraction mode for certain models:
- For XClip: Use "mit" to extract from Multi-frame Integration Transformer
layers instead of vision backbone layers.
- For LLaVA models: Must be a prompt string containing the "<video>" token
(e.g., "<|user|><video><|end|><|assistant|>").
num_frames : int | None, default=None
Number of frames to pass to the video model per clip. If None, uses the
model's default frame count (e.g., 16 for VideoMAE, 8 for XClip, 64 for VJepa2).
"""
MODELS = (
"vjepa2",
"videomae",
"microsoft/xclip",
"google/vivit",
"facebook/timesformer",
"LLaVA-NeXT-Video",
"LLaVA-Video",
"Phi-4",
)
# language + video models https://arxiv.org/pdf/2405.21075
def __init__(
self,
model_name: str,
pretrained: bool = True,
layer_type: str = "",
num_frames: int | None = None,
) -> None:
super().__init__()
if not any(z in model_name for z in self.MODELS):
raise ValueError(f"Model {model_name!r} is not supported")
Model: tp.Any # ignore typing as we'll override the imports
Processor: tp.Any
from transformers import AutoModel as Model
from transformers import AutoProcessor as Processor
extra: dict[str, tp.Any] = {}
processor_extra: dict[str, tp.Any] = {"do_rescale": True}
if "google/vivit" in model_name:
from transformers import VivitImageProcessor as Processor
from transformers import VivitModel as Model # type: ignore
if "LLaVA" in model_name:
from transformers import LlavaNextVideoForConditionalGeneration as Model
from transformers import LlavaNextVideoProcessor as Processor
extra = {"torch_dtype": torch.float16}
if "34B" in model_name:
extra["device_map"] = "auto" # uses accelerate
if "Phi-4" in model_name:
from transformers import AutoModelForCausalLM as Model
extra = {"_attn_implementation": "eager", "trust_remote_code": True}
processor_extra["trust_remote_code"] = True
if "vjepa2" in model_name:
from transformers import AutoVideoProcessor as Processor
self.model = Model.from_pretrained(model_name, output_hidden_states=True, **extra)
if not pretrained:
self.model = Model.from_config(self.model.config)
self.model.eval()
# use do_rescale=True -> don't use totensor
self.processor = Processor.from_pretrained(model_name, **processor_extra)
self.model_name = model_name
self.layer_type = layer_type
if "llava" in model_name.lower():
max_frames = 16 # any number works
elif "vjepa2" in model_name.lower():
max_frames = 64
elif "Phi-4" in model_name:
max_frames = 4 # TODO: make this flexible?
else:
config = self.model.config
config = getattr(config, "vision_config", config) # xclip
max_frames = config.num_frames
if num_frames is None:
self.num_frames = max_frames
else:
self.num_frames = num_frames
if self.num_frames > max_frames:
raise ValueError(
f"{model_name} only seems to supports {max_frames} frames, got {self.num_frames}"
)
self.check_layer_type(layer_type, model_name)
@staticmethod
def check_layer_type(layer_type: str, model_name: str) -> None:
if "xclip" in model_name and layer_type == "mit":
return # is ok
if "llava" in model_name.lower():
if "<video>" not in layer_type:
msg = f"For {model_name!r}, layer_type must be a prompt with the <video> token\n"
# note: best aggregation was: mean
raise ValueError(msg)
return # all good
if layer_type:
raise ValueError(f"No layer type available for {model_name!r}")
def predict(self, images: np.ndarray, audio: tp.Any | None = None) -> tp.Any:
kwargs: dict[str, tp.Any] = {"text": "", "return_tensors": "pt"}
field = "images"
if "xclip" in self.model_name:
field = "videos"
elif "llava" in self.model_name.lower():
field = "videos"
kwargs["text"] = self.layer_type
elif "vjepa2" in self.model_name:
field = "videos"
del kwargs["text"]
elif "Phi-4" in self.model_name:
import PIL
images = [PIL.Image.fromarray(img) for img in images] # type: ignore
field = "images"
prompt = "<|user|>"
for i in range(1, len(images) + 1):
prompt += f"<|image_{i}|>"
if audio is not None:
kwargs["audios"] = [(audio.to_soundarray(), audio.fps)] # type: ignore
prompt += "<|audio_1|>"
prompt += "<|end|><|assistant|>"
kwargs["text"] = prompt
kwargs[field] = list(images)
inputs = self.processor(**kwargs)
# prevent nans (happening for uniform images)
_fix_pixel_values(inputs)
inputs = inputs.to(self.model.device)
with torch.inference_mode():
pred = self.model(**inputs)
return pred
def predict_hidden_states(
self, images: np.ndarray, audio: np.ndarray | None = None
) -> torch.Tensor:
pred = self.predict(images, audio)
if "xclip" in self.model_name:
# MIT: Multi-frame Integration Transformer
is_mit = self.layer_type == "mit"
pred = pred.mit_output if is_mit else pred.vision_model_output
# [8, 13, 197, 768] for vision model, [1, 2, 8, 512] for mit model
states = pred.hidden_states
out = torch.cat([x.unsqueeze(1) for x in states], axis=1) # type: ignore
if "xclip" in self.model_name and not self.layer_type:
out = out[[-1], ...] # last batch/timepoint only
return out # B x L x ...