Source code for neuralset.extractors.image

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import typing as tp

import numpy as np
import torch
from exca import MapInfra
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from neuralset import base, utils
from neuralset.events import etypes

from . import base as extractor_base
from . import hf

logger = logging.getLogger(__name__)

# Default job resources passed to the MapInfra instances of the extractors below.
CLUSTER_DEFAULTS: dict[str, tp.Any] = {
    "timeout_min": 25,
    "gpus_per_node": 1,
    "cpus_per_task": 8,
    "min_samples_per_job": 4096,
}


def _fix_pixel_values(inputs: dict[str, tp.Any]) -> None:
    # prevent nans (happening for uniform images)
    if "pixel_values" in inputs:
        nans = inputs["pixel_values"].isnan()
        if nans.any():
            inputs["pixel_values"][nans] = 0
            inputs["pixel_values"] = inputs["pixel_values"].float()


class _ImageDataset(Dataset):
    """PyTorch Dataset for loading and transforming image events.

    This dataset wraps a sequence of image events and applies an optional
    transformation to each image when accessed.

    Parameters
    ----------
    events:
        Sequence of image events; item ``idx`` is obtained with
        ``events[idx].read()``.
    transform:
        Optional callable applied to the image returned by ``read()``.
    """

    def __init__(self, events: tp.Sequence[etypes.Image], transform=None):
        self.events = events
        self.transform = transform

    def __len__(self) -> int:
        return len(self.events)

    def __getitem__(self, idx: int):
        """Read event ``idx`` and apply the transform (if any) to it."""
        event = self.events[idx]
        try:
            image = event.read()
            if self.transform is not None:
                image = self.transform(image)
        except Exception:
            # add context about the failing event, then propagate the original error
            # (KeyboardInterrupt/SystemExit are not caught, so they raise without a
            # misleading warning)
            logger.warning("Failed to process image event %s", event)
            raise
        return image

    @staticmethod
    def collate_fn(images: list[torch.Tensor]) -> tp.Any:
        """Stack same-shaped tensors into a batch; otherwise return the list as-is.

        Tensors of different sizes cannot be concatenated; for HuggingFace
        models the processor (applied later) handles the resizing.
        """
        if all(img.shape == images[0].shape for img in images):
            return torch.stack(images)
        return images


class _VideoImage(etypes.Image):
    """Image event wrapper for extracting individual frames from a video.

    The frame at ``time`` is read lazily from ``video`` when the event is read.
    ``filepath`` is derived from ``video.filename`` and ``time``. It must not be
    provided by the caller.
    """

    start: float = 0.0
    timeline: str = "fake"
    duration: float = 1.0
    # opened video object; must expose ``filename`` and ``get_frame(t)``
    video: tp.Any
    # timestamp of the frame inside the video
    time: float = 0.0
    filepath: str = ""

    def model_post_init(self, log__: tp.Any) -> None:
        if self.filepath:
            raise ValueError("Filepath is automatically filled")
        # unique, human-readable identifier of the frame: "<video file>:<time>"
        self.filepath = f"{self.video.filename}:{self.time:.3f}"
        super().model_post_init(log__)

    def _read(self) -> tp.Any:
        """Return the frame at ``self.time`` as a PIL image (cast to uint8)."""
        # ``import PIL`` alone does not load the ``PIL.Image`` submodule, so
        # ``PIL.Image`` only worked if another module had imported it first.
        from PIL import Image  # noqa

        with utils.ignore_all():
            frame = self.video.get_frame(self.time)
        return Image.fromarray(frame.astype("uint8"))


def _huggingface_image_event_uid(event: etypes.Image | etypes.Video) -> str:
    """Return the cache item uid of an image or video event.

    Video events use their splittable event uid. Image events use the string
    form of their study-relative path.
    """
    if not isinstance(event, etypes.Video):
        return str(event.study_relative_path())
    return event._splittable_event_uid()


class HuggingFaceImageConfig(hf.HuggingFaceConfig):
    """HuggingFace configuration for image models.

    ``processor_kwargs`` disables the processor's rescaling by default. This is
    presumably because images are already converted to [0, 1] tensors by
    ``transforms.ToTensor`` upstream. ``HF_CLASS_DEFAULTS`` maps the "clip" and
    "dinov2" keys to the model/processor class names to load. How the key is
    selected is defined by the parent config (not visible here).
    """

    processor_kwargs: dict[str, tp.Any] | None = dict(do_rescale=False)
    HF_CLASS_DEFAULTS: tp.ClassVar[dict[str, dict[str, str]]] = dict(
        clip=dict(
            model_cls_name="CLIPModel",
            processor_cls_name="CLIPProcessor",
        ),
        dinov2=dict(
            model_cls_name="Dinov2Model",
            processor_cls_name="AutoImageProcessor",
        ),
    )


class HuggingFaceImage(extractor_base.BaseStatic, hf.HuggingFaceMixin):
    """Compute image embeddings using transformer-based models obtained through HuggingFace API.

    Image events are embedded in batches. For Video events
    (``event_types="Video"``), frames are sampled at ``frequency`` and each
    frame is embedded.

    Parameters
    ----------
    model_name : str, default="facebook/dinov2-base"
        HuggingFace model identifier.
    batch_size : int, default=32
        Number of images per forward pass (excluded from the class uid).
    imsize : int | None, default=None
        Optionally resize images with torchvision ``Resize`` before the
        HuggingFace processor (whose own resizing may cancel this effect).
    frequency : float | "native", default=0.0
        Frame sampling frequency for Video events. ``"native"`` uses the video
        event's own frequency. Otherwise a positive value is required for
        Video events. Image events always get frequency 0.
    """

    # class attributes
    event_types: tp.Literal["Image", "Video"] = "Image"
    requirements: tp.ClassVar[tuple[str, ...]] = (
        "torchvision>=0.15.2",
        "transformers>=4.29.2",
        "pillow>=9.2.0",
    )
    model_name: str = "facebook/dinov2-base"
    hf_config: HuggingFaceImageConfig = HuggingFaceImageConfig()
    # for precomputing/caching
    infra: MapInfra = MapInfra(version="v6", **CLUSTER_DEFAULTS)
    batch_size: int = 32
    imsize: int | None = None
    frequency: float | tp.Literal["native"] = 0.0  # type: ignore[assignment]

    @classmethod
    def _exclude_from_cls_uid(cls) -> list[str]:
        return (
            ["batch_size"]
            + extractor_base.BaseStatic._exclude_from_cls_uid()
            + hf.HuggingFaceMixin._exclude_from_cls_uid()
        )

    def _exclude_from_cache_uid(self) -> list[str]:
        return extractor_base.BaseStatic._exclude_from_cache_uid(
            self
        ) + hf.HuggingFaceMixin._exclude_from_cache_uid(self)

    def _iter_image_latents(
        self, events: tp.Sequence[etypes.Image], aggregate_layers: bool
    ) -> tp.Iterator[np.ndarray]:
        """Yield one latent numpy array per image event, in input order.

        Tokens are always aggregated; layers are aggregated only when
        ``aggregate_layers`` is True.
        """
        from torchvision import transforms

        logger.info("Computing %s image latents", len(events))
        transfs = [transforms.ToTensor()]
        if self.imsize is not None:
            transfs = [transforms.Resize(self.imsize)] + transfs
        dset = _ImageDataset(events, transform=transforms.Compose(transfs))
        dloader = DataLoader(
            dset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=_ImageDataset.collate_fn,
        )
        if len(events) > 1:
            dloader = tqdm(dloader, desc="Computing image embeddings")  # type: ignore
        # Embed the images in batches.
        # Grad mode is thread-local global state: holding ``torch.no_grad()``
        # across a ``yield`` would keep gradients disabled in the caller while
        # this generator is suspended, so the context only wraps computations.
        for batch_images in dloader:
            if isinstance(batch_images, torch.Tensor):
                batch_images = batch_images.to(self.model_device)
            else:  # should be list of different sizes
                batch_images = [i.to(self.model_device) for i in batch_images]
            with torch.no_grad():
                latents = self._extract_batched_latents(batch_images)
            for latent in latents:
                # notes:
                # - aggregating with a batch would be slightly more efficient
                #   but code would be messier
                # - aggregating in cuda avoids transferring too much data to cpu
                with torch.no_grad():
                    latent = self._aggregate_tokens(latent)
                    if aggregate_layers:
                        latent = self._aggregate_layers(latent)
                    array = latent.cpu().numpy()
                yield array

    @infra.apply(
        item_uid=_huggingface_image_event_uid,
        exclude_from_cache_uid="method:_exclude_from_cache_uid",
        cache_type="MemmapArrayFile",
    )
    def _get_data(
        self, events: tp.Sequence[etypes.Image | etypes.Video]
    ) -> tp.Iterator[np.ndarray]:
        """Yield one (cached) latent array per event.

        When ``cache_n_layers`` is None, layers are aggregated before caching.
        Otherwise layer aggregation is left to the callers, applied after
        reading the cache.
        """
        if self.event_types == "Video":
            for event in tp.cast(tp.Sequence[etypes.Video], events):
                yield self._get_video_data(event)
            return
        yield from self._iter_image_latents(
            tp.cast(tp.Sequence[etypes.Image], events),
            aggregate_layers=self.cache_n_layers is None,
        )

    def _get_video_data(self, event: etypes.Video) -> np.ndarray:
        """Embed frames sampled from a video event.

        The number of frames is derived from the event duration and the
        sampling frequency (at least 1). Frames are taken at the end of equal
        intervals spanning the video (t=0 excluded). The frame axis is moved
        last, and the output is cast to float32.

        Raises
        ------
        ValueError
            If ``frequency`` is neither "native" nor a positive number.
        """
        freq_setting = self.frequency
        if freq_setting != "native" and float(freq_setting) <= 0:
            msg = "HuggingFaceImage requires frequency='native' or a positive frequency for Video events."
            raise ValueError(msg)
        video = event.read()
        try:
            freq = event.frequency if self.frequency == "native" else self.frequency
            expect_frames = max(1, base.Frequency(freq).to_ind(event.duration))
            times = np.linspace(0, video.duration, expect_frames + 1)[1:]
            frames = [_VideoImage(video=video, time=float(t)) for t in times]
            embeddings = [
                np.asarray(embd)
                for embd in self._iter_image_latents(
                    frames,
                    aggregate_layers=self.cache_n_layers is None,
                )
            ]
            output = np.stack(embeddings, axis=0)
            # move the frame axis last: (frames, ...) -> (..., frames)
            output = output.transpose(list(range(1, output.ndim)) + [0])
            return output.astype(np.float32)
        finally:
            video.close()

    def model_post_init(self, log__):
        if self.imsize is not None:
            utils.warn_once(
                f'The effect of "imsize"={self.imsize} might be cancelled by '
                "the HuggingFace processor."
            )
        super().model_post_init(log__)

    def _full_predict(  # return the raw output, used in tests
        self, images: torch.Tensor, text: str | list[str] = ""
    ) -> tp.Any:
        """Run the processor and the model, returning the raw model output.

        Hidden states of all layers are requested. The forward pass runs under
        ``torch.inference_mode``.
        """
        kwargs: dict[str, tp.Any] = dict(
            images=[i.float() for i in images], return_tensors="pt"
        )
        if text:
            kwargs["text"] = text
        inputs = self.processor(**kwargs)
        _fix_pixel_values(inputs)
        inputs = inputs.to(self.model_device)
        with torch.inference_mode():
            return self.model(**inputs, output_hidden_states=True)

    def _extract_batched_latents(
        self, images: torch.Tensor | list[torch.Tensor]
    ) -> torch.Tensor:
        """Return the hidden states of all layers stacked along dim 1.

        Output shape: (batch, n_layers, tokens, n_features).

        Raises
        ------
        RuntimeError
            If the model returns no hidden states.
        """
        out = self._full_predict(images)
        out = getattr(out, "vision_model_output", out)  # for clip
        states = out.hidden_states
        if states is None:
            raise RuntimeError(
                f"Model {self.model_name!r} returned hidden_states=None. "
                "This is a known regression in transformers>=5 where some "
                "encoders (CLIP, SAM, ViT) no longer collect intermediate "
                "hidden states."
            )
        # (batch, n_layers, tokens, n_features)
        return torch.stack(tuple(states), dim=1)

    def _get_timed_arrays(
        self,
        events: list[etypes.Image | etypes.Video],
        start: float,
        duration: float,
    ) -> tp.Iterable[base.TimedArray]:
        """Yield one TimedArray per event.

        For Video events, the array is cropped to the requested [start, start + duration]
        window. Image events yield static arrays (frequency 0).
        """
        if self.event_types == "Video":
            video_events = tp.cast(list[etypes.Video], events)
            for event, latents in zip(video_events, self._get_data(video_events)):
                freq = event.frequency if self.frequency == "native" else self.frequency
                tarray = base.TimedArray(
                    data=np.asarray(latents),
                    frequency=freq,
                    start=base._UNSET_START,
                    duration=event.duration,
                )
                sub = tarray.with_start(event.start).overlap(
                    start=start, duration=duration
                )
                if self.cache_n_layers is not None:
                    sub.data = self._aggregate_layers(sub.data)
                yield sub
        elif self.event_types == "Image":
            for image_event, latents in zip(events, self._get_data(events)):
                if self.cache_n_layers is not None:
                    latents = self._aggregate_layers(latents)
                yield base.TimedArray(
                    frequency=0,
                    duration=image_event.duration,
                    start=image_event.start,
                    data=np.asarray(latents),
                )
        else:
            msg = f"Unsupported event_types={self.event_types!r} for HuggingFaceImage"
            raise ValueError(msg)

    def get_static(self, event: etypes.Image) -> torch.Tensor:
        """Return the latent of a single image event as a float tensor.

        Raises
        ------
        TypeError
            If this extractor is configured for Video events.
        """
        if self.event_types == "Video":
            raise TypeError("Use HuggingFaceImage.__call__ for Video events.")
        # layer * patches * size
        latent = next(self._get_data([event]))
        # plain ndarray view of the (possibly memmapped) data. np.asarray replaces
        # np.array(..., copy=False), which raises on NumPy>=2 if a copy is needed.
        latent = np.asarray(latent)
        if self.cache_n_layers is not None:
            latent = self._aggregate_layers(latent)
        # copy needed: memmap arrays are read-only
        return torch.Tensor(np.array(latent, copy=True))
class BaseClassicImageExtractor(extractor_base.BaseStatic):
    """Base class for classic image extractors, e.g. based on numpy, skimage, OpenCV, etc.

    Subclasses implement ``_get_image_features``, which receives the image as
    a numpy array and returns its feature vector.

    Parameters
    ----------
    imsize : int | None, default to None
        Optionally resize images to a square of (imsize, imsize) before
        passing them to the model (aspect ratio is not preserved).
        If None, use the original image size.
    """

    # class attributes
    event_types: tp.Literal["Image"] = "Image"
    imsize: int | None = None
    infra: MapInfra = MapInfra(version="v5", **CLUSTER_DEFAULTS)

    @infra.apply(
        item_uid=lambda event: str(event.study_relative_path()),
        cache_type="MemmapArrayFile",
    )
    def _get_data(self, events: list[etypes.Image]) -> tp.Iterator[np.ndarray]:
        """Yield one (cached) feature array per image event, in input order."""
        logger.info("Computing %s for %s images.", type(self).__name__, len(events))
        for event in events:
            # presumably a PIL image (``resize`` with a (w, h) tuple) — TODO confirm
            image = event.read()
            if self.imsize is not None:
                image = image.resize((self.imsize, self.imsize))
            # some subclasses (e.g. RFFT2D) return torch tensors: normalize to
            # numpy so the MemmapArrayFile cache always receives arrays
            yield np.asarray(self._get_image_features(np.array(image)))

    def _get_image_features(self, image: np.ndarray) -> np.ndarray:
        """Compute the feature vector of one image; implemented by subclasses."""
        raise NotImplementedError

    def get_static(self, event: etypes.Image) -> torch.Tensor:
        return torch.Tensor(np.asarray(next(self._get_data([event]))))
class RFFT2D(BaseClassicImageExtractor):
    """(Cropped) 2D Fourier spectrum of an image of real values.

    Parameters
    ----------
    n_components_to_keep :
        Number of components of the FFT to keep, starting from low frequencies and moving towards
        higher frequencies. If None, use all components. Values larger than half the number of
        image rows keep all rows of the spectrum.
    average_channels :
        If True, average RGB channels before taking the FFT (to reduce dimensionality).
    return_log_psd :
        If True, return the flattened log PSD instead of the "viewed-as-real" complex FFT.
    return_angle :
        If True, return the flattened angle. Can be combined with the log PSD.
    """

    requirements: tp.ClassVar[tuple[str, ...]] = ("torchvision>=0.15.2",)
    n_components_to_keep: int | None = None
    average_channels: bool = True
    return_log_psd: bool = False
    return_angle: bool = False
    _eps: tp.ClassVar[float] = 1e-12

    def _fft(self, image: torch.Tensor) -> torch.Tensor:
        """Return the shifted (and optionally channel-averaged / cropped) real FFT.

        ``image`` is (channels, rows, cols), as produced by ``to_tensor``. The
        output is (channels or 1, rows, cols // 2 + 1) before cropping. The zero
        frequency is centred along the row axis.
        """
        fft = torch.fft.rfft2(image)
        if self.average_channels:
            # the FFT is linear, so this equals the FFT of the channel average
            fft = fft.mean(dim=0, keepdim=True)
        fft = torch.fft.fftshift(fft, dim=1)
        if self.n_components_to_keep is not None:
            # Crop FFT by keeping lower frequencies. Clamp the row bounds: a
            # negative start would wrap around and silently keep only the last
            # few (high-frequency) rows.
            n_rows = fft.shape[1]
            mid_point_x = n_rows // 2
            n = self.n_components_to_keep
            lo = max(mid_point_x - n, 0)
            hi = min(mid_point_x + n, n_rows)
            fft = fft[:, lo:hi, : n + 1]
        return fft

    @staticmethod
    def _ifft(
        fft: torch.Tensor, average_channels: bool, width: int, height: int
    ) -> torch.Tensor:
        """Convenience function to return in image-space after an FFT.

        Only supports "viewed as real" FFT. A flattened (1D) input is
        unflattened to ``(1 or 3, width, height // 2 + 1, 2)`` and converted
        back to complex. Since ``rfft2`` was applied on (channels, rows, cols),
        ``width`` must be the number of rows and ``height`` the number of
        columns. Only uncropped spectra can be unflattened this way. The
        result is scaled so that its maximum is 1.

        NOTE(review): ``irfft2`` is called without ``s``, so the output has
        ``2 * (height // 2)`` columns (odd column counts lose one column).
        """
        if fft.ndim == 1:
            fft = fft.reshape(  # Unflatten and convert back to complex
                1 if average_channels else 3,
                width,
                height // 2 + 1,
                2,
            )
            fft = torch.view_as_complex(fft)
        fft = torch.fft.ifftshift(fft, dim=1)
        inv_fft = torch.fft.irfft2(fft).real
        inv_fft = inv_fft / inv_fft.max()
        return inv_fft

    def _get_image_features(self, image: np.ndarray) -> torch.Tensor:
        """Return the flattened spectral features of one image.

        The log PSD and/or angle are concatenated along the last axis when
        requested. Otherwise the complex FFT is viewed as real (real/imag pairs).
        """
        import torchvision.transforms.functional as TF  # noqa

        fft = self._fft(TF.to_tensor(image))
        out = []
        if self.return_log_psd:
            out.append((fft.abs() ** 2 + self._eps).log())
        if self.return_angle:
            out.append(fft.angle())
        if not (self.return_log_psd or self.return_angle):
            out.append(torch.view_as_real(fft))  # Complex tensor -> Real vector
        features = torch.cat(out, dim=-1).flatten()
        return features
class HOG(BaseClassicImageExtractor):
    """Histogram of oriented gradients (Dalal & Triggs, 2005).

    Computed with ``skimage.feature.hog`` using 8 orientations, 8x8-pixel cells
    and 2x2-cell blocks, on a channel-last image (``channel_axis=-1``).
    See https://scikit-image.org/docs/stable/auto_examples/features_detection/plot_hog.html

    References
    ----------
    .. [1] Dalal, N. and Triggs, B., "Histograms of Oriented Gradients for
           Human Detection," IEEE Computer Society Conference on Computer
           Vision and Pattern Recognition, 2005, San Diego, CA, USA.
           https://ieeexplore.ieee.org/document/1467360
    """

    requirements: tp.ClassVar[tuple[str, ...]] = ("scikit-image>=0.22.0",)
    _orientations: tp.ClassVar[int] = 8
    _pixels_per_cell: tp.ClassVar[tuple[int, int]] = (8, 8)
    _cells_per_block: tp.ClassVar[tuple[int, int]] = (2, 2)
    _channel_axis: tp.ClassVar[int] = -1

    def _get_image_features(self, image: np.ndarray) -> np.ndarray:
        from skimage.feature import hog  # noqa

        hog_options = dict(
            orientations=self._orientations,
            pixels_per_cell=self._pixels_per_cell,
            cells_per_block=self._cells_per_block,
            channel_axis=self._channel_axis,
            visualize=False,
        )
        return hog(image, **hog_options)
class LBP(BaseClassicImageExtractor):
    """Local Binary Pattern (LBP).

    The image is converted to grayscale, unless it already has a single
    channel. Then "uniform" LBP codes (P=8, R=1) are computed, and the
    normalized 10-bin histogram of the codes over [0, 10) is returned.
    See https://scikit-image.org/docs/stable/auto_examples/features_detection/plot_local_binary_pattern.html
    """

    requirements: tp.ClassVar[tuple[str, ...]] = (
        "opencv-python>=4.8.1",
        "scikit-image>=0.22.0",
    )
    _P: tp.ClassVar[int] = 8
    _R: tp.ClassVar[int] = 1
    _method: tp.ClassVar[str] = "uniform"
    _n_bins: tp.ClassVar[int] = 10
    _bin_range: tp.ClassVar[tuple[int, int]] = (0, 10)

    def _get_image_features(self, image: np.ndarray) -> np.ndarray:
        import cv2  # noqa
        from skimage.feature import local_binary_pattern  # noqa

        # LBP requires grayscale; single-channel images (e.g. PIL "L" mode) would
        # make cvtColor raise, so they are used directly
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        if image.ndim == 2:
            gray = image
        else:
            # NOTE(review): images are presumably RGB (PIL), while BGR2GRAY swaps
            # the red/blue weights. Switching to cv2.COLOR_RGB2GRAY changes the
            # cached features, so bump the cache version together with that fix.
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        lbp = local_binary_pattern(gray, P=self._P, R=self._R, method=self._method)
        hist, _ = np.histogram(lbp.ravel(), bins=self._n_bins, range=self._bin_range)
        hist = hist.astype("float")
        # normalize to a distribution; epsilon avoids a division by zero
        hist /= hist.sum() + 1e-7
        return hist
class ColorHistogram(BaseClassicImageExtractor):
    """Color histogram.

    Joint 8x8x8-bin histogram over the three image channels (each over
    [0, 256)). It is normalized in place with ``cv2.normalize`` (default
    arguments) and flattened.
    See https://docs.opencv.org/3.4/d8/dbc/tutorial_histogram_calculation.html
    """

    requirements: tp.ClassVar[tuple[str, ...]] = ("opencv-python>=4.8.1",)
    _channels: tp.ClassVar[tuple[int, ...]] = (0, 1, 2)
    _hist_size: tp.ClassVar[tuple[int, ...]] = (8, 8, 8)
    _ranges: tp.ClassVar[tuple[int, ...]] = (0, 256, 0, 256, 0, 256)

    def _get_image_features(self, image: np.ndarray) -> np.ndarray:
        import cv2  # noqa

        histogram = cv2.calcHist(
            [image], self._channels, None, self._hist_size, self._ranges
        )
        normalized = cv2.normalize(histogram, histogram)
        return normalized.flatten()