# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import typing as tp
import numpy as np
import pydantic
import torch
from exca import MapInfra
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from neuralset.events import etypes
from neuralset.utils import warn_once
from .base import BaseStatic, HuggingFaceMixin
# Module-level logger, named after the importing module.
logger = logging.getLogger(__name__)

# Keyword arguments unpacked into the default ``MapInfra`` of the extractors
# declared in this module.
CLUSTER_DEFAULTS: dict[str, tp.Any] = {
    "timeout_min": 25,
    "gpus_per_node": 1,
    "cpus_per_task": 8,
    "min_samples_per_job": 4096,
}
def _fix_pixel_values(inputs: dict[str, tp.Any]) -> None:
# prevent nans (happening for uniform images)
if "pixel_values" in inputs:
nans = inputs["pixel_values"].isnan()
if nans.any():
inputs["pixel_values"][nans] = 0
inputs["pixel_values"] = inputs["pixel_values"].float()
class _HuggingFace(nn.Module):
    """Wrapper that provides a unified interface for loading and using various HuggingFace
    image models (ViT, DINOv2, CLIP, etc.) with support for hidden state extraction
    from all layers.

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier (e.g., "facebook/dinov2-base").
        The model will be loaded from the HuggingFace Hub. Please note that you may
        have to install additional dependencies to load it correctly.
    output_hidden_states : bool, default=False
        Whether to extract hidden states from all transformer layers. If False, only
        the hidden state from the last layer is returned.
    pretrained : bool, default=True
        Whether to load pretrained weights. If False, initializes the model with
        random weights from the model configuration (the pretrained checkpoint is
        still loaded first, to obtain that configuration).

    Raises
    ------
    ValueError
        If ``from_pretrained`` fails with an error that matches none of the
        special cases handled in ``__init__``.
    """

    def __init__(
        self,
        model_name: str,
        output_hidden_states: bool = False,
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        Model: tp.Any  # ignore typing as we'll override the imports
        Processor: tp.Any
        from transformers import AutoModel as Model
        from transformers import AutoProcessor as Processor

        if model_name == "facebook/dpt-dinov2-base-kitti":
            from transformers import DPTForDepthEstimation as Model
        try:
            self.model = Model.from_pretrained(
                model_name, output_hidden_states=output_hidden_states
            )
        except ValueError as e:
            # handle specific cases: the error message tells which dedicated
            # class must be used instead of the Auto* one
            message = str(e)
            if "VisionEncoderDecoderConfig" in message:
                from transformers import VisionEncoderDecoderModel as Model
                from transformers import ViTImageProcessor as Processor
            elif "vit-hybrid" in message:
                from transformers import ViTHybridForImageClassification as Model
                from transformers import ViTHybridImageProcessor as Processor
            elif "UperNetConfig" in message:
                from transformers import UperNetForSemanticSegmentation as Model
            else:
                # no alternative class to try: repeating the very same call
                # would only fail again, so surface the original error
                raise
            self.model = Model.from_pretrained(
                model_name, output_hidden_states=output_hidden_states
            )
        if not pretrained:
            # keep the architecture, drop the pretrained weights
            self.model = Model.from_config(self.model.config)
        self.model.eval()
        # do_rescale=False because ToTensor does the rescaling
        self.processor = Processor.from_pretrained(model_name, do_rescale=False)
        self.model_name = model_name

    def _full_predict(  # return the raw output, used in tests
        self, images: torch.Tensor, text: str | list[str] = ""
    ) -> tp.Any:
        """Run processor then model on ``images`` and return the raw model output.

        ``text`` is forwarded to the processor only when non-empty. The forward
        pass runs under ``torch.inference_mode``.
        """
        kwargs: dict[str, tp.Any] = dict(
            images=[i.float() for i in images], return_tensors="pt"
        )
        if text:
            kwargs["text"] = text
        inputs = self.processor(**kwargs)
        _fix_pixel_values(inputs)
        inputs = inputs.to(self.model.device)
        with torch.inference_mode():
            pred = self.model(**inputs)
        return pred

    def forward(self, images) -> torch.Tensor:
        """Return the ``last_hidden_state`` of the model output for ``images``."""
        pred = self._full_predict(images)
        # for clip: use the vision branch of the output when it exists
        pred = getattr(pred, "vision_model_output", pred)
        return pred.last_hidden_state
class _ImageDataset(Dataset):
    """PyTorch Dataset for loading and transforming image events.

    This dataset wraps a sequence of image events and applies optional transformations
    to each image when accessed. It is primarily used by the BaseImage class
    to efficiently compute image features in batches.

    Parameters
    ----------
    events :
        Image events; each item is loaded through its ``read()`` method.
    transform :
        Optional callable applied to the loaded image.
    """

    def __init__(self, events: tp.Sequence[etypes.Image], transform=None):
        self.events = events
        self.transform = transform

    def __len__(self) -> int:
        return len(self.events)

    def __getitem__(self, idx: int):
        """Load (and transform, if requested) the image of event ``idx``.

        Any failure is logged with the offending event, then re-raised.
        """
        try:
            image = self.events[idx].read()
            if self.transform:
                image = self.transform(image)
        except Exception:
            # "Exception" rather than a bare except, so that KeyboardInterrupt
            # and SystemExit are not reported as image processing failures
            logger.warning("Failed to process image event %s", self.events[idx])
            raise
        return image

    @staticmethod
    def collate_fn(images: list[torch.Tensor]) -> tp.Any:
        """Stack same-shaped images into one tensor, else return the list as is."""
        # we can't concatenate if the outputs have different sizes
        # for huggingface -> transform is applied later
        # (an empty list is returned unchanged: torch.stack rejects it)
        if images and all(i.shape == images[0].shape for i in images):
            return torch.stack(images)
        return images
class BaseImage(BaseStatic, HuggingFaceMixin):
    """Base class for computing features from image events using batch processing.

    This class provides the infrastructure for extracting features from sequences of images
    using neural network models. It handles batching, device management and optional image
    resizing. Subclasses must implement `_extract_batched_latents` to define
    the specific feature extraction logic.

    Parameters
    ----------
    batch_size : int, default=32
        Number of images to process in a batch.
    imsize : int | None, default to None
        Target size for resizing images before processing. If specified, it is passed
        as an integer to ``torchvision.transforms.Resize``, which matches the smaller
        edge of the image to ``imsize`` and preserves the aspect ratio (the result is
        square only for square inputs). If None, original image dimensions are
        preserved.
    """

    # class attributes
    event_types: tp.Literal["Image"] = "Image"
    requirements: tp.ClassVar[tuple[str, ...]] = (
        "torchvision>=0.15.2",
        "transformers>=4.29.2",
        "pillow>=9.2.0",
    )
    # extractor attributes
    batch_size: int = 32
    imsize: int | None = None
    _model: nn.Module = pydantic.PrivateAttr()  # initialized later

    @classmethod
    def _exclude_from_cls_uid(cls) -> list[str]:
        # "batch_size" plus whatever both parents already exclude
        return (
            ["batch_size"]
            + BaseStatic._exclude_from_cls_uid()
            + HuggingFaceMixin._exclude_from_cls_uid()
        )

    def _exclude_from_cache_uid(self) -> list[str]:
        # concatenation of the exclusions of both parents
        return BaseStatic._exclude_from_cache_uid(
            self
        ) + HuggingFaceMixin._exclude_from_cache_uid(self)

    def _make_transform(self) -> tp.Any:
        """Build the per-image transform: optional ``Resize`` followed by ``ToTensor``."""
        from torchvision import transforms

        transfs = [transforms.ToTensor()]
        if self.imsize is not None:
            transfs = [transforms.Resize(self.imsize)] + transfs
        return transforms.Compose(transfs)

    def _get_data(self, events: tp.Sequence[etypes.Image]) -> tp.Iterator[np.ndarray]:
        """Yield one token-aggregated latent (as a numpy array) per event, in order.

        Images are loaded through a ``DataLoader`` in batches of ``batch_size``,
        moved to ``self.device`` and passed to ``_extract_batched_latents``.
        """
        logger.info("Computing %s image latents", len(events))
        dset = _ImageDataset(events, transform=self._make_transform())
        dloader = DataLoader(
            dset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=_ImageDataset.collate_fn,
        )
        if len(events) > 1:
            dloader = tqdm(dloader, desc="Computing image embeddings")  # type: ignore

        # Embed the images in batches
        for batch_images in dloader:
            if isinstance(batch_images, torch.Tensor):
                batch_images = batch_images.to(self.device)
            else:  # should be list of different sizes
                batch_images = [i.to(self.device) for i in batch_images]
            # The no_grad block is closed before yielding: a context manager left
            # open across a ``yield`` would keep gradients disabled in the code
            # consuming this generator while it is suspended.
            with torch.no_grad():
                latents = self._extract_batched_latents(batch_images)
                # notes: - aggregating with a batch would be slightly more efficient
                #          but code would be messier
                #        - aggregating in cuda avoids transfering too much data to cpu
                outputs = [
                    self._aggregate_tokens(latent).cpu().numpy() for latent in latents
                ]
            yield from outputs

    def get_static(self, event: etypes.Image) -> torch.Tensor:
        raise NotImplementedError

    def _extract_batched_latents(self, images: torch.Tensor) -> torch.Tensor:
        """Compute the latents of a batch of images (to be implemented by subclasses)."""
        raise NotImplementedError
# [docs]  <- Sphinx "view source" link marker left by extraction; not Python code
class HuggingFaceImage(BaseImage):
    """Compute image embeddings using transformer-based models obtained through HuggingFace API.

    Parameters
    ----------
    model_name : str, default="facebook/dinov2-base"
        HuggingFace model identifier.
    pretrained : bool, default=True
        Whether to load pretrained weights from model. If False, initializes the model with
        random weights from the model configuration.
    """

    # class attributes
    model_name: str = "facebook/dinov2-base"
    # extractor attributes
    pretrained: bool = True
    # for precomputing/caching
    infra: MapInfra = MapInfra(version="v5", **CLUSTER_DEFAULTS)

    def _exclude_from_cache_uid(self) -> list[str]:
        prev = super()._exclude_from_cache_uid()
        return prev + ["duration", "frequency"]

    @infra.apply(
        item_uid=lambda e: str(e.study_relative_path()),
        exclude_from_cache_uid="method:_exclude_from_cache_uid",
        cache_type="MemmapArrayFile",
    )
    def _get_data(self, events: tp.Sequence[etypes.Image]) -> tp.Iterator[np.ndarray]:
        """Yield the latent of each event, layer-aggregated unless layers are cached."""
        for latents in super()._get_data(events):
            if self.cache_n_layers is None:
                # no per-layer caching requested: aggregate layers before caching
                latents = self._aggregate_layers(latents)
            yield latents

    def model_post_init(self, log__):
        if self.imsize is not None:
            warn_once(
                f'The effect of "imsize"={self.imsize} might be cancelled by '
                "the HuggingFace processor."
            )
        super().model_post_init(log__)

    @property
    def model(self) -> nn.Module:
        """The wrapped HuggingFace model, created on first access."""
        if not hasattr(self, "_model") or self._model is None:
            self._model = _HuggingFace(
                model_name=self.model_name,
                output_hidden_states=True,
                pretrained=self.pretrained,
            )
            self._model.to(self.device)
        return self._model

    def _get_hidden_states(self, images: torch.Tensor) -> list[torch.Tensor]:
        """Extract hidden_states as n_layers x (batch, tokens, features)

        Raises
        ------
        RuntimeError
            If the model output carries ``hidden_states=None``.
        """
        # this method is overriden in experimental extractors for more hugging face models
        out = self.model._full_predict(images)  # type: ignore
        out = getattr(out, "vision_model_output", out)  # for clip
        states = out.hidden_states
        if states is None:
            raise RuntimeError(
                f"Model {self.model_name!r} returned hidden_states=None. "
                "This is a known regression in transformers>=5 where some "
                "encoders (CLIP, SAM, ViT) no longer collect intermediate "
                "hidden states. See RELEASE_PLAN.md for the tracking item."
            )
        return states  # type: ignore

    def _extract_batched_latents(self, images: torch.Tensor) -> torch.Tensor:
        """Concatenate the hidden states of all layers along a new dimension 1."""
        states = self._get_hidden_states(images)
        out = torch.cat([x.unsqueeze(1) for x in states], dim=1)
        # (batch, n_layers, tokens, n_features)
        return out  # type: ignore

    def get_static(self, event: etypes.Image) -> torch.Tensor:
        """Return the latent of a single event as a tensor."""
        # layer * patches * size
        latent = next(self._get_data([event]))
        # make sure it's loaded from memmap; np.asarray never forces a copy and,
        # unlike np.array(..., copy=False), does not raise under NumPy >= 2 when
        # a copy turns out to be necessary
        latent = np.asarray(latent)
        if self.cache_n_layers is not None:
            latent = self._aggregate_layers(latent)
        # copy needed: memmap arrays are read-only
        return torch.Tensor(np.array(latent, copy=True))
class BaseClassicImageExtractor(BaseStatic):
    """Base class for classic image extractors, e.g. based on numpy, skimage, OpenCV, etc.

    Subclasses provide the actual feature computation in `_get_image_features`.

    Parameters
    ----------
    imsize : int | None, default to None
        Optionally resize images to (imsize, imsize) before passing them to the
        model. If None, use the original image size.
    """

    # class attributes
    event_types: tp.Literal["Image"] = "Image"
    imsize: int | None = None
    infra: MapInfra = MapInfra(version="v5", **CLUSTER_DEFAULTS)

    @infra.apply(
        item_uid=lambda event: str(event.study_relative_path()),
        cache_type="MemmapArrayFile",
    )
    def _get_data(self, events: list[etypes.Image]) -> tp.Iterator[np.ndarray]:
        """Yield the features of each event's image, in order."""
        logger.info("Computing %s for %s images.", type(self).__name__, len(events))
        for event in events:
            picture = event.read()
            if self.imsize is None:
                resized = picture
            else:
                side = self.imsize
                resized = picture.resize((side, side))
            yield self._get_image_features(np.array(resized))

    def _get_image_features(self, image: np.ndarray) -> np.ndarray:
        """Compute the features of one image array (to be implemented by subclasses)."""
        raise NotImplementedError

    def get_static(self, event: etypes.Image) -> torch.Tensor:
        """Return the features of a single event as a tensor."""
        features = next(self._get_data([event]))
        return torch.Tensor(np.asarray(features))
# [docs]  <- Sphinx "view source" link marker left by extraction; not Python code
class RFFT2D(BaseClassicImageExtractor):
    """(Cropped) 2D Fourier spectrum of an image of real values.

    Parameters
    ----------
    n_components_to_keep :
        Number of components of the FFT to keep, starting from low frequencies and
        moving towards higher frequencies. If None, use all components. Values larger
        than what the spectrum holds keep everything available.
    average_channels :
        If True, average RGB channels before taking the FFT (to reduce dimensionality).
    return_log_psd :
        If True, return the flattened log PSD instead of the "viewed-as-real" complex FFT.
    return_angle :
        If True, return the flattened angle. Can be combined with the log PSD.
    """

    requirements: tp.ClassVar[tuple[str, ...]] = ("torchvision>=0.15.2",)
    n_components_to_keep: int | None = None
    average_channels: bool = True
    return_log_psd: bool = False
    return_angle: bool = False
    _eps: tp.ClassVar[float] = 1e-12  # added to the PSD before taking its log

    def _fft(self, image: torch.Tensor) -> torch.Tensor:
        """Return the (optionally channel-averaged and cropped) ``rfft2`` of ``image``.

        The result is indexed with three dimensions; dimension 1 is shifted so
        that its zero frequency sits in the middle.
        """
        fft = torch.fft.rfft2(image)
        if self.average_channels:
            fft = fft.mean(axis=0, keepdims=True)
        fft = torch.fft.fftshift(fft, dim=1)

        if self.n_components_to_keep is not None:  # Crop FFT by keeping lower frequencies
            mid_point_x = fft.shape[1] // 2
            n = self.n_components_to_keep
            # clamp at 0: a negative start would be read as an index from the
            # end and silently select the wrong (or no) rows
            lo = max(mid_point_x - n, 0)
            hi = mid_point_x + n
            fft = fft[:, lo:hi, : n + 1]
        return fft

    @staticmethod
    def _ifft(
        fft: torch.Tensor, average_channels: bool, width: int, height: int
    ) -> torch.Tensor:
        """Convenience function to return in image-space after an FFT.

        Only supports "viewed as real" FFT.

        A 1-dimensional input is first reshaped to
        ``(1 if average_channels else 3, width, height // 2 + 1, 2)``; a real input
        is then converted to complex, while an already complex input is used as is.
        The output is divided by its maximum.
        """
        if fft.ndim == 1:
            fft = fft.reshape(  # Unflatten
                1 if average_channels else 3,
                width,
                height // 2 + 1,
                2,
            )
        if not torch.is_complex(fft):
            fft = torch.view_as_complex(fft)  # convert back to complex
        fft = torch.fft.ifftshift(fft, dim=1)
        inv_fft = torch.fft.irfft2(fft).real
        inv_fft = inv_fft / inv_fft.max()
        return inv_fft

    def _get_image_features(self, image: np.ndarray) -> torch.Tensor:
        """Return the flattened spectrum features selected by the ``return_*`` flags."""
        import torchvision.transforms.functional as TF  # noqa

        fft = self._fft(TF.to_tensor(image))

        out = []
        if self.return_log_psd:
            out.append((fft.abs() ** 2 + self._eps).log())
        if self.return_angle:
            out.append(fft.angle())
        if not (self.return_log_psd or self.return_angle):
            out.append(torch.view_as_real(fft))  # Complex tensor -> Real vector

        features = torch.cat(out, dim=-1).flatten()
        return features
# [docs]  <- Sphinx "view source" link marker left by extraction; not Python code
class HOG(BaseClassicImageExtractor):
    """Histogram of oriented gradients (Dalal & Triggs, 2005).

    See https://scikit-image.org/docs/stable/auto_examples/features_detection/plot_hog.html

    References
    ----------
    .. [1] Dalal, N. and Triggs, B., "Histograms of Oriented Gradients for
        Human Detection," IEEE Computer Society Conference on Computer Vision
        and Pattern Recognition, 2005, San Diego, CA, USA.
        https://ieeexplore.ieee.org/document/1467360
    """

    requirements: tp.ClassVar[tuple[str, ...]] = ("scikit-image>=0.22.0",)
    # fixed arguments forwarded to ``skimage.feature.hog``
    _orientations: tp.ClassVar[int] = 8
    _pixels_per_cell: tp.ClassVar[tuple[int, int]] = (8, 8)
    _cells_per_block: tp.ClassVar[tuple[int, int]] = (2, 2)
    _channel_axis: tp.ClassVar[int] = -1

    def _get_image_features(self, image: np.ndarray) -> np.ndarray:
        """Return the output of ``skimage.feature.hog`` for ``image``."""
        from skimage.feature import hog  # noqa

        hog_options = {
            "orientations": self._orientations,
            "pixels_per_cell": self._pixels_per_cell,
            "cells_per_block": self._cells_per_block,
            "channel_axis": self._channel_axis,
            "visualize": False,
        }
        return hog(image, **hog_options)
# [docs]  <- Sphinx "view source" link marker left by extraction; not Python code
class LBP(BaseClassicImageExtractor):
    """Local Binary Pattern (LBP).

    The image is converted to grayscale, LBP codes are computed, and their
    normalized histogram is returned.

    See https://scikit-image.org/docs/stable/auto_examples/features_detection/plot_local_binary_pattern.html
    """

    requirements: tp.ClassVar[tuple[str, ...]] = (
        "opencv-python>=4.8.1",
        "scikit-image>=0.22.0",
    )
    # fixed arguments forwarded to ``local_binary_pattern``
    _P: tp.ClassVar[int] = 8
    _R: tp.ClassVar[int] = 1
    _method: tp.ClassVar[str] = "uniform"
    # fixed arguments forwarded to ``np.histogram``
    _n_bins: tp.ClassVar[int] = 10
    _bin_range: tp.ClassVar[tuple[int, int]] = (0, 10)

    def _get_image_features(self, image: np.ndarray) -> np.ndarray:
        """Return the normalized histogram of the LBP codes of ``image``."""
        import cv2  # noqa
        from skimage.feature import local_binary_pattern  # noqa

        # requires grayscale
        # NOTE(review): COLOR_BGR2GRAY assumes BGR channel order; if ``image``
        # arrives as RGB (TODO confirm what ``event.read()`` yields), the red and
        # blue weights are swapped and COLOR_RGB2GRAY would be the matching flag.
        grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        codes = local_binary_pattern(
            grayscale, P=self._P, R=self._R, method=self._method
        )
        counts, _edges = np.histogram(
            codes.ravel(), bins=self._n_bins, range=self._bin_range
        )
        counts = counts.astype("float")
        # small constant keeps the division defined when all counts are 0
        total = counts.sum() + 1e-7
        return counts / total
# [docs]  <- Sphinx "view source" link marker left by extraction; not Python code
class ColorHistogram(BaseClassicImageExtractor):
    """Color histogram.

    Computes a joint histogram over the three image channels with OpenCV,
    normalizes it and flattens it.

    See https://docs.opencv.org/3.4/d8/dbc/tutorial_histogram_calculation.html
    """

    requirements: tp.ClassVar[tuple[str, ...]] = ("opencv-python>=4.8.1",)
    # fixed arguments forwarded to ``cv2.calcHist``
    _channels: tp.ClassVar[tuple[int, ...]] = (0, 1, 2)
    _hist_size: tp.ClassVar[tuple[int, ...]] = (8, 8, 8)
    _ranges: tp.ClassVar[tuple[int, ...]] = (0, 256, 0, 256, 0, 256)

    def _get_image_features(self, image: np.ndarray) -> np.ndarray:
        """Return the normalized, flattened color histogram of ``image``."""
        import cv2  # noqa

        # mask is None: every pixel contributes
        raw_hist = cv2.calcHist(
            [image], self._channels, None, self._hist_size, self._ranges
        )
        # the histogram doubles as destination buffer of the normalization
        normalized = cv2.normalize(raw_hist, raw_hist)
        return normalized.flatten()