Source code for neuralbench.modules

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""``nn.Module`` classes and pydantic config wrappers used by neuralbench.

This module groups the small custom ``nn.Module`` building blocks
(``IndexSelect``, ``Mean``, ``ConcatGroupedMean``) together with the
pydantic configs that build trainable adapters / probes on top of a
pretrained model (``ChannelProjection``, ``DownstreamWrapper``) and the
runtime wrapper they produce (``DownstreamWrapperModel``).
"""

import inspect
import logging
import typing as tp

import pydantic
import torch
from torch import nn

from neuraltrain.models.common import ChannelMerger, Mlp
from neuraltrain.models.preprocessor import OnTheFlyPreprocessor

LOGGER = logging.getLogger(__name__)


class IndexSelect(nn.Module):
    """Select specific indices along a dimension, squeezing if only one index is selected."""

    def __init__(self, dim: int, index: torch.Tensor) -> None:
        super().__init__()
        self.dim = dim
        self.index = index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.index_select(x, dim=self.dim, index=self.index.to(x.device))
        if len(self.index) == 1:
            out = out.squeeze(self.dim)
        return out


class Mean(nn.Module):
    """Reduce a tensor by averaging over one or more dimensions."""

    def __init__(self, dim: int | tuple[int, ...]) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mean(dim=self.dim)


class ConcatGroupedMean(nn.Module):
    """Split a tensor into ``n_splits`` groups along ``dim``, average each group, then concatenate."""

    def __init__(self, dim: int, n_splits: int) -> None:
        super().__init__()
        self.dim = dim
        self.n_splits = n_splits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.concat(
            [
                xi.mean(dim=self.dim)
                for xi in torch.tensor_split(x, self.n_splits, dim=self.dim)
            ],
            dim=self.dim,
        )


class ChannelProjection(pydantic.BaseModel):
    """Configuration for a Conv1d(kernel_size=1) channel projection adapter.

    Projects from an arbitrary number of input channels to a fixed target count
    via a learned pointwise (1x1) convolution.  Unlike ``ChannelMerger``, this
    does not use channel positions -- it is a simple linear mixing matrix applied
    identically at every time step.

    Parameters
    ----------
    n_target_channels : int
        Number of output channels (must match what the pretrained model expects).
    max_norm : float or None
        If set, applies a max-norm weight constraint on the Conv1d kernel
        (following braindecode's ``Conv1dWithConstraint``). Default is 1.0.
        Set to None to disable the constraint.
    init : {"random", "identity", "bipolar"}
        Initialisation scheme for the Conv1d kernel.

        * ``"random"`` (default): PyTorch's default Kaiming-uniform
          initialisation.
        * ``"identity"``: rows of the kernel that correspond to target channel
          names present (possibly under a rename) in the input channel list
          are set to a one-hot vector selecting that input; all other rows are
          zero.  On step 0 this makes the adapter a pass-through for the
          matching channels and a zero pad for missing ones, mirroring the
          behaviour of the zero-fill channel remapping that the adapter
          replaces.  Requires :attr:`target_channel_names` and, at build time,
          ``input_channel_names``.
        * ``"bipolar"``: rows correspond to bipolar derivations specified as
          ``"A-B"`` strings in ``target_channel_names`` (unipolar entries
          ``"A"`` are also accepted and treated as ``"A-None"``, i.e. +1 only
          on A).  For each covered pair, +1 is added to ``weight[target,
          idx(A), 0]`` and -1 to ``weight[target, idx(B), 0]`` **on top of**
          the default Kaiming-uniform init (additive, not overwriting).
          Rows whose pair is partially or fully missing retain the Kaiming
          baseline so they can still learn -- important for models whose
          first op is ``|STFT|`` (e.g. BIOT), where a fully zero row would
          yield a zero gradient through the ``abs`` singularity.  Requires
          :attr:`target_channel_names` and, at build time,
          ``input_channel_names``.
    target_channel_names : list[str] or None
        Required for ``init="identity"`` and ``init="bipolar"``.  Names of the
        target channels in the order the pretrained model expects them.  Must
        have length ``n_target_channels``.  Matching against input channel
        names is case-insensitive.  For ``init="bipolar"`` entries may be
        bipolar pairs ``"A-B"``; entries without ``-`` are treated as
        unipolar.
    rename_mapping : dict[str, str] or None
        Optional canonicalisation map applied to input channel names before
        identity/bipolar-init matching.  E.g. ``{"T3": "T7", "E9": "Fp2"}``.
        Ignored when ``init="random"``.
    """

    model_config = pydantic.ConfigDict(extra="forbid")
    n_target_channels: int
    max_norm: float | None = 1.0
    init: tp.Literal["random", "identity", "bipolar"] = "random"
    target_channel_names: list[str] | None = None
    rename_mapping: dict[str, str] | None = None

    def model_post_init(self, __context):
        super().model_post_init(__context)
        if self.init in ("identity", "bipolar"):
            if self.target_channel_names is None:
                raise ValueError(
                    f"init={self.init!r} requires target_channel_names to be set."
                )
            if len(self.target_channel_names) != self.n_target_channels:
                raise ValueError(
                    f"target_channel_names has length {len(self.target_channel_names)} "
                    f"but n_target_channels is {self.n_target_channels}."
                )

    def build(
        self,
        n_in_channels: int,
        input_channel_names: list[str] | None = None,
    ) -> nn.Module:
        if self.max_norm is not None:
            from braindecode.modules import Conv1dWithConstraint

            conv: nn.Conv1d = Conv1dWithConstraint(
                n_in_channels,
                self.n_target_channels,
                kernel_size=1,
                max_norm=self.max_norm,
            )
        else:
            conv = nn.Conv1d(n_in_channels, self.n_target_channels, kernel_size=1)

        if self.init == "identity":
            self._apply_identity_init(conv, input_channel_names)
        elif self.init == "bipolar":
            self._apply_bipolar_init(conv, input_channel_names)

        return conv

    def _apply_identity_init(
        self,
        conv: nn.Conv1d,
        input_channel_names: list[str] | None,
    ) -> None:
        """Overwrite ``conv`` weights with a name-matched identity pattern.

        For each target channel name, finds the (possibly renamed) input
        channel whose name matches and sets ``weight[target, input, 0] = 1``.
        All other weights and the bias are zero.
        """
        if input_channel_names is None:
            raise ValueError(
                "init='identity' requires input_channel_names at build time; "
                "the calling DownstreamWrapper must forward them."
            )
        assert self.target_channel_names is not None  # enforced in post-init
        n_in = conv.weight.shape[1]
        if len(input_channel_names) != n_in:
            raise ValueError(
                f"input_channel_names has length {len(input_channel_names)} "
                f"but conv expects {n_in} input channels."
            )

        rename = self.rename_mapping or {}
        canon_inputs = [rename.get(name, name).upper() for name in input_channel_names]
        target_upper_to_idx: dict[str, int] = {}
        for i, tname in enumerate(self.target_channel_names):
            target_upper_to_idx.setdefault(tname.upper(), i)

        weight = torch.zeros_like(conv.weight)
        covered: set[int] = set()
        for j, canon in enumerate(canon_inputs):
            target_idx = target_upper_to_idx.get(canon)
            if target_idx is None:
                continue
            if target_idx in covered:
                LOGGER.warning(
                    "ChannelProjection identity init: target channel %r is "
                    "already covered by another input; input %r (%r) "
                    "contributes additively.",
                    self.target_channel_names[target_idx],
                    input_channel_names[j],
                    canon,
                )
            weight[target_idx, j, 0] = 1.0
            covered.add(target_idx)

        missing = [
            self.target_channel_names[i]
            for i in range(self.n_target_channels)
            if i not in covered
        ]
        LOGGER.info(
            "ChannelProjection identity init: %d/%d target channels covered "
            "(missing: %s); %d/%d input channels contribute.",
            len(covered),
            self.n_target_channels,
            missing if missing else "none",
            sum(1 for canon in canon_inputs if canon in target_upper_to_idx),
            n_in,
        )

        with torch.no_grad():
            _raw_conv_weight(conv).copy_(weight)
            if conv.bias is not None:
                conv.bias.zero_()

    def _apply_bipolar_init(
        self,
        conv: nn.Conv1d,
        input_channel_names: list[str] | None,
    ) -> None:
        """Add a name-matched bipolar pattern on top of the Kaiming baseline.

        For each target entry parsed as ``"A-B"`` (or unipolar ``"A"``), finds
        the (possibly renamed) input channels for A and B and **adds** +1 /
        -1 to ``weight[target, idx(A), 0]`` / ``weight[target, idx(B), 0]`` on
        top of the default Kaiming-uniform init produced by
        ``nn.Conv1d.__init__``.  Rows whose pair is partially or fully
        missing keep the Kaiming baseline.

        This additive choice is critical for models whose first op is a
        magnitude spectrogram (e.g. BIOT's ``|torch.stft(x)|``): an
        identically-zero row would produce a zero output, and the ``abs``
        singularity at zero would freeze the row's gradient, so the adapter
        could never recover.  The small Kaiming noise lets every row learn
        while the +1/-1 pattern preserves the pretrained channel-token
        semantics on covered pairs at step 0.
        """
        if input_channel_names is None:
            raise ValueError(
                "init='bipolar' requires input_channel_names at build time; "
                "the calling DownstreamWrapper must forward them."
            )
        assert self.target_channel_names is not None  # enforced in post-init
        n_in = conv.weight.shape[1]
        if len(input_channel_names) != n_in:
            raise ValueError(
                f"input_channel_names has length {len(input_channel_names)} "
                f"but conv expects {n_in} input channels."
            )

        rename = self.rename_mapping or {}
        canon_inputs = [rename.get(name, name).upper() for name in input_channel_names]
        canon_to_idx: dict[str, int] = {}
        for j, canon in enumerate(canon_inputs):
            canon_to_idx.setdefault(canon, j)

        pattern = torch.zeros_like(conv.weight)
        fully_covered: list[str] = []
        partial: list[str] = []
        missing: list[str] = []
        for i, tname in enumerate(self.target_channel_names):
            if "-" in tname:
                pos_name, neg_name = tname.split("-", 1)
            else:
                pos_name, neg_name = tname, None
            pos_idx = canon_to_idx.get(pos_name.upper())
            neg_idx = canon_to_idx.get(neg_name.upper()) if neg_name else None

            if neg_name is None:
                if pos_idx is not None:
                    pattern[i, pos_idx, 0] = 1.0
                    fully_covered.append(tname)
                else:
                    missing.append(tname)
            else:
                if pos_idx is not None and neg_idx is not None:
                    pattern[i, pos_idx, 0] = 1.0
                    pattern[i, neg_idx, 0] = -1.0
                    fully_covered.append(tname)
                elif pos_idx is not None or neg_idx is not None:
                    partial.append(tname)
                else:
                    missing.append(tname)

        LOGGER.info(
            "ChannelProjection bipolar init: %d/%d fully covered (+1/-1 added "
            "over Kaiming baseline), %d partial (kept at Kaiming baseline so "
            "they can still learn), %d missing (kept at Kaiming baseline). "
            "Covered: %s. Partial: %s. Missing: %s.",
            len(fully_covered),
            self.n_target_channels,
            len(partial),
            len(missing),
            fully_covered if fully_covered else "none",
            partial if partial else "none",
            missing if missing else "none",
        )

        with torch.no_grad():
            # Additive on top of the Kaiming-uniform init that
            # ``nn.Conv1d.__init__`` already applied.
            _raw_conv_weight(conv).add_(pattern)
            if conv.bias is not None:
                conv.bias.zero_()


def _raw_conv_weight(conv: nn.Conv1d) -> torch.Tensor:
    """Return the underlying trainable weight tensor for ``conv``.

    ``Conv1dWithConstraint`` (and any other ``torch.nn.utils.parametrize``
    user) registers parametrizations on ``weight`` -- reading ``conv.weight``
    yields the parametrised *view* and writing to it is rejected unless the
    parametrization defines a ``right_inverse``.  This helper returns the
    raw, trainable tensor so callers can mutate it in-place during custom
    initialisation.
    """
    parametrizations = getattr(conv, "parametrizations", None)
    if parametrizations is not None and "weight" in parametrizations:
        return parametrizations["weight"].original  # type: ignore[no-any-return]
    return conv.weight


[docs] class DownstreamWrapper(pydantic.BaseModel): """Configuration for wrapping a (pretrained) model for downstream fine-tuning or linear probing. This class provides a declarative way to configure how a pretrained model should be adapted for downstream tasks, including optional on-the-fly preprocessing, layer freezing, output aggregation, and adding a trainable probe on top of the model. Parameters ---------- on_the_fly_preprocessor : OnTheFlyPreprocessor | None, optional On-the-fly preprocessing applied to the input before the model forward pass. Typically model-specific (e.g. QuantileAbsScaler for BIOT). Default is None. channel_adapter_config : ChannelMerger | ChannelProjection | None, optional Configuration for a channel adapter that projects from arbitrary input channels to a fixed number of target channels. Supply a ``ChannelMerger`` for position-based spatial attention, or a ``ChannelProjection`` for a simple Conv1d(kernel_size=1) linear mixing. Default is None. model_output_key : str | int | None, optional Key or index to extract from model output dictionary. If None, assumes the model returns a tensor directly. Default is None. layers_to_freeze : list[str] | None, optional List of layer name patterns to freeze (set requires_grad=False). Cannot be used together with layers_to_unfreeze. Default is None. layers_to_unfreeze : list[str] | tp.Literal["last"] | None, optional List of layer name patterns to unfreeze (set requires_grad=True), while freezing all others. Cannot be used together with layers_to_freeze. If "last", unfreezes the last layer (nn.Module) of the model. Default is None. strict_matching : bool, optional If True, when freezing/unfreezing layers, only the first part of the layer name (before the first dot) must match exactly. If False, any part of the layer name can match the patterns. Default is True. aggregation : {"flatten", "mean", "first"} or int, optional Method to aggregate the model output. ``"flatten"`` flattens all dimensions except batch; ``"mean"`` averages over the temporal/sequence dimension (dim=1); ``"first"`` selects only the first timestep/token; an ``int`` splits into n groups, averages each group, then concatenates; ``None`` performs no aggregation. probe_config : Mlp | "linear" | None, optional Configuration for the probe layer added on top. ``None`` uses identity (no additional layer), e.g. if the model already has a linear layer of the right output size. ``"linear"`` adds a single linear layer. An ``Mlp`` instance adds a multi-layer perceptron with specified configuration. """ model_config = pydantic.ConfigDict(extra="forbid") on_the_fly_preprocessor: OnTheFlyPreprocessor | None = None channel_adapter_config: ChannelMerger | ChannelProjection | None = None model_output_key: str | int | None = None layers_to_freeze: list[str] | None = None layers_to_unfreeze: list[str] | tp.Literal["last"] | None = None strict_matching: bool = True aggregation: tp.Literal["flatten", "mean", "first"] | int | None = "flatten" probe_config: Mlp | tp.Literal["linear"] | None = "linear" @property def n_adapter_target_channels(self) -> int | None: """Target channel count of the adapter, or ``None`` if no adapter is configured.""" if self.channel_adapter_config is None: return None if isinstance(self.channel_adapter_config, ChannelMerger): return self.channel_adapter_config.n_virtual_channels return self.channel_adapter_config.n_target_channels def model_post_init(self, __context): super().model_post_init(__context) if self.layers_to_freeze is not None and self.layers_to_unfreeze is not None: raise ValueError( "Only one of layers_to_freeze and layers_to_unfreeze can be specified at once." )
[docs] def build( self, model: nn.Module, dummy_batch: dict[str, torch.Tensor | None], n_outputs: int, input_channel_names: list[str] | None = None, ) -> "DownstreamWrapperModel": preprocessor = None if self.on_the_fly_preprocessor is not None: preprocessor = self.on_the_fly_preprocessor.build() channel_adapter: nn.Module | None = None adapter_needs_positions = False if self.channel_adapter_config is not None: if isinstance(self.channel_adapter_config, ChannelMerger): channel_adapter = self.channel_adapter_config.build() adapter_needs_positions = True else: input_key = next(iter(dummy_batch)) x = dummy_batch[input_key] assert x is not None channel_adapter = self.channel_adapter_config.build( x.shape[1], input_channel_names=input_channel_names ) with torch.no_grad(): model.eval() if channel_adapter is not None: input_key = next(iter(dummy_batch)) x = dummy_batch[input_key] assert x is not None if adapter_needs_positions: subject_ids = dummy_batch.get( "subject_ids", torch.zeros(x.shape[0], dtype=torch.long), ) ch_pos = dummy_batch.get("channel_positions") x_adapted = channel_adapter(x, subject_ids, ch_pos) else: x_adapted = channel_adapter(x) model_batch = {input_key: x_adapted} orig_output = model(**model_batch) else: orig_output = model(**dummy_batch) if self.model_output_key is not None: orig_output = orig_output[self.model_output_key] model.train() wrapper_model = DownstreamWrapperModel( model, orig_output.shape[1:], preprocessor=preprocessor, channel_adapter=channel_adapter, adapter_needs_positions=adapter_needs_positions, model_output_key=self.model_output_key, wrapper_n_outputs=n_outputs, layers_to_freeze=self.layers_to_freeze, layers_to_unfreeze=self.layers_to_unfreeze, strict_matching=self.strict_matching, aggregation=self.aggregation, probe_config=self.probe_config, ) # Sanity check (wrapper handles preprocessing internally) wrapper_output = wrapper_model(**dummy_batch) assert wrapper_output.shape[-1] == n_outputs return wrapper_model
[docs] class DownstreamWrapperModel(nn.Module): """Wrapper for downstream evaluation of pretrained models. Handles the full pipeline: optional preprocessing -> channel adapter -> model -> output key selection -> aggregation -> probe. """ def __init__( self, model: nn.Module, brain_model_output_size: torch.Size, model_output_key: str | int | None, wrapper_n_outputs: int, preprocessor: nn.Module | None = None, channel_adapter: nn.Module | None = None, adapter_needs_positions: bool = False, layers_to_freeze: list[str] | None = None, layers_to_unfreeze: list[str] | tp.Literal["last"] | None = None, strict_matching: bool = True, aggregation: tp.Literal["flatten", "mean", "first"] | int | None = "flatten", probe_config: Mlp | tp.Literal["linear"] | None = None, ): super().__init__() self.preprocessor = preprocessor self.channel_adapter = channel_adapter self._adapter_needs_positions = adapter_needs_positions self.wrapped_model = model self.model_output_key = model_output_key inner_sig = inspect.signature(model.forward) self._inner_param_names = set(inner_sig.parameters.keys()) self._inner_accepts_var_kwargs = any( p.kind == inspect.Parameter.VAR_KEYWORD for p in inner_sig.parameters.values() ) self._apply_freeze(layers_to_freeze, layers_to_unfreeze, strict_matching) n_inputs = self._build_aggregation(aggregation, brain_model_output_size) self._build_probe(probe_config, n_inputs, wrapper_n_outputs) def _apply_freeze( self, layers_to_freeze: list[str] | None, layers_to_unfreeze: list[str] | tp.Literal["last"] | None, strict_matching: bool, ) -> None: """Freeze or unfreeze model parameters based on layer name patterns.""" if layers_to_freeze is not None: for name, param in self.wrapped_model.named_parameters(): if strict_matching: requires_grad = name.split(".")[0] not in layers_to_freeze else: requires_grad = not any( pattern in name for pattern in layers_to_freeze ) param.requires_grad = requires_grad elif layers_to_unfreeze == "last": param_names = list(self.wrapped_model.named_parameters()) last_layer_name = param_names[-1][0].rsplit(".", 1)[0] unfrozen_layers = [] for name, param in self.wrapped_model.named_parameters(): if name.startswith(last_layer_name): unfrozen_layers.append(name) param.requires_grad = True else: param.requires_grad = False LOGGER.warning(f"Unfreezing {unfrozen_layers}") elif layers_to_unfreeze is not None: for name, param in self.wrapped_model.named_parameters(): if strict_matching: requires_grad = name.split(".")[0] in layers_to_unfreeze else: requires_grad = any(pattern in name for pattern in layers_to_unfreeze) param.requires_grad = requires_grad def _build_aggregation( self, aggregation: tp.Literal["flatten", "mean", "first"] | int | None, brain_model_output_size: torch.Size, ) -> int: """Build the aggregation module and return the flattened input size for the probe.""" self.aggregation: nn.Module if aggregation is None: self.aggregation = nn.Identity() return brain_model_output_size.numel() elif aggregation == "flatten": self.aggregation = nn.Flatten(start_dim=1) return brain_model_output_size.numel() elif aggregation == "first": assert len(brain_model_output_size) == 2 self.aggregation = IndexSelect(dim=1, index=torch.LongTensor([0])) return brain_model_output_size[1] elif aggregation == "mean": dim: int | tuple[int, ...] = 1 if len(brain_model_output_size) == 2: # (n_patches, emb_dim) dim = 1 elif len(brain_model_output_size) == 3: # (n_chans, n_patches, emb_dim) dim = (1, 2) else: raise ValueError( f"aggregation='mean' requires model output of 3D or 4D " f"(got brain_model_output_size={brain_model_output_size})" ) self.aggregation = Mean(dim=dim) return brain_model_output_size[-1] elif isinstance(aggregation, int): assert len(brain_model_output_size) == 2 self.aggregation = ConcatGroupedMean(dim=1, n_splits=aggregation) return aggregation * brain_model_output_size[1] else: raise NotImplementedError() def _build_probe( self, probe_config: Mlp | tp.Literal["linear"] | None, n_inputs: int, n_outputs: int, ) -> None: """Build the probe (classification/regression head) on top of the aggregated representations.""" self.probe: nn.Module if probe_config is None: self.probe = nn.Identity() elif probe_config == "linear": self.probe = nn.Linear(n_inputs, n_outputs) else: assert not isinstance(probe_config, str) self.probe = probe_config.build( input_size=n_inputs, output_size=n_outputs, )
[docs] def forward(self, *args, return_embedding: bool = False, **kwargs) -> torch.Tensor: if self.preprocessor is not None: if args: x, *rest_args = args x, ch_pos = self.preprocessor(x, kwargs.get("channel_positions")) args = (x, *rest_args) else: input_key = next(iter(kwargs)) x = kwargs[input_key] x, ch_pos = self.preprocessor(x, kwargs.get("channel_positions")) kwargs = {**kwargs, input_key: x} if ch_pos is not None and "channel_positions" in kwargs: kwargs["channel_positions"] = ch_pos if self.channel_adapter is not None: input_key = next(iter(kwargs)) x = kwargs[input_key] if self._adapter_needs_positions: subject_ids = kwargs.get( "subject_ids", torch.zeros(x.shape[0], dtype=torch.long, device=x.device), ) ch_pos = kwargs.get("channel_positions") x = self.channel_adapter(x, subject_ids, ch_pos) else: x = self.channel_adapter(x) kwargs = {**kwargs, input_key: x} if not self._inner_accepts_var_kwargs: kwargs = {k: v for k, v in kwargs.items() if k in self._inner_param_names} out = self.wrapped_model(*args, **kwargs) if self.model_output_key is not None: out = out[self.model_output_key] out = self.aggregation(out) if return_embedding: return out out = self.probe(out) return out