# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Common modules to be used with brain models."""
# pylint: disable=unused-variable
import math
import typing as tp
from collections import deque
import torch
from torch import nn
from torchvision.ops import MLP
from neuraltrain.models.base import BaseModelConfig
INVALID_POS_VALUE = -0.1 # See ns.extractors.ChannelPositions.INVALID_VALUE
def parse_bipolar_name(name: str) -> tuple[str, str] | None:
"""Split a bipolar channel name into its anode and cathode.
Returns ``("Fp1", "F3")`` for ``"Fp1-F3"``, or ``None`` when *name*
does not look like a bipolar pair (i.e. does not contain exactly one
hyphen separating two non-empty parts).
"""
parts = name.split("-")
if len(parts) == 2 and parts[0] and parts[1]:
return parts[0], parts[1]
return None
def compute_temporal_adjustment(n_times: int, patch_size: int) -> tuple[int, int]:
"""Compute how to align *n_times* to a multiple of *patch_size*.
Returns ``(pad_right, truncate_right)`` -- at most one is non-zero.
* ``n_times < patch_size`` -- pad to one full patch.
* ``n_times`` not divisible by ``patch_size`` -- truncate the remainder.
* Already aligned -- ``(0, 0)``.
"""
if n_times < patch_size:
return patch_size - n_times, 0
remainder = n_times % patch_size
if remainder != 0:
return 0, remainder
return 0, 0
def apply_temporal_adjustment(
x: torch.Tensor,
pad_right: int = 0,
truncate_right: int = 0,
) -> torch.Tensor:
"""Zero-pad or truncate the last (time) dimension of *x*."""
if pad_right > 0:
return torch.nn.functional.pad(x, (0, pad_right))
if truncate_right > 0:
return x[..., :-truncate_right]
return x
[docs]
class BahdanauAttention(nn.Module):
"""Bahdanau attention from [1]_.
Implementation inspired from pytorch's seq2seq tutorial:
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#the-decoder
.. [1] Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio. "Neural machine translation by
jointly learning to align and translate." arXiv preprint arXiv:1409.0473 (2014).
"""
def __init__(self, input_size, hidden_size):
super().__init__()
if input_size is None:
self.Wa = nn.LazyLinear(hidden_size)
self.Ua = nn.LazyLinear(hidden_size)
else:
self.Wa = nn.Linear(input_size, hidden_size)
self.Ua = nn.Linear(input_size, hidden_size)
self.Va = nn.Linear(hidden_size, 1)
[docs]
def forward(self, keys, queries=None):
"""
Parameters
----------
keys:
Key tensor of shape (batch_size, n_features, n_times).
queries:
Optional query tensor of shape (batch_size, n_features, n_times).
If None, only keys are used.
"""
keys = keys.transpose(2, 1) # (B, F, T) -> (B, T, F)
sum_ = self.Wa(keys)
if queries is not None:
queries = queries.transpose(2, 1)
assert queries.shape == keys.shape
sum_ += self.Ua(queries)
scores = self.Va(torch.tanh(sum_))
scores = scores.squeeze(2).unsqueeze(1)
weights = nn.functional.softmax(scores, dim=-1)
context = torch.bmm(weights, keys)
context = context.transpose(2, 1) # (B, 1, F) -> (B, F, 1)
return context
[docs]
class ChannelDropout(nn.Module):
def __init__(self):
super().__init__()
raise NotImplementedError("See brainmagick.models.common.")
[docs]
def forward(self, x):
raise NotImplementedError
[docs]
class SubjectLayers(BaseModelConfig):
"""Configuration for per-subject linear projections.
Parameters
----------
n_subjects : int
Number of subjects to allocate weight matrices for.
bias : bool
Include a bias term in each subject's projection.
init_id : bool
Initialize projection matrices to the identity (requires
``in_channels == out_channels``).
mode : {"gather", "for_loop"}
``"gather"`` builds a ``(B, C_in, C_out)`` tensor via index_select
(fast but memory-heavy for large channel counts). ``"for_loop"``
iterates over unique subjects (slower but lighter).
subject_dropout : float or None
Probability of replacing a subject index with a shared "dropout
subject" during training. Required when *average_subjects* is True.
average_subjects : bool
At inference time, use the shared dropout-subject weights for all
examples. Requires *subject_dropout* to be set.
"""
n_subjects: int = 200
bias: bool = True
init_id: bool = False
mode: tp.Literal["gather", "for_loop"] = "gather"
subject_dropout: float | None = None
average_subjects: bool = False
def build(self, in_channels: int, out_channels: int) -> nn.Module:
kwargs = self.model_dump()
del kwargs["name"]
return SubjectLayersModel(in_channels, out_channels, **kwargs)
class SubjectLayersModel(nn.Module):
"""Per subject linear projection.
Parameters
----------
in_channels :
Number of input channels.
out_channels :
Number of output channels.
n_subjects :
Number of subjects to initialize weights for.
bias:
If True, use a bias term.
init_id :
If True, initialize the projection matrices with the identity.
mode :
How to apply the linear projection. With "gather" (original implementation), a tensor of
shape (batch_size, in_channels, out_channels) containing the projection matrices for each
example in the batch is first created. This tensor can be very large when the number of
channels is high (e.g. when using on fMRI data with many input voxels). In this case, it
may be better to use "for_loop": this will loop over each unique subject in the batch to
apply the projection separately.
subject_dropout :
If not None, probability with which a subject's index is replaced by a shared "dropout
subject" during training. An extra row is added to the weight matrix for this purpose.
Required when ``average_subjects`` is True.
average_subjects :
If True, use the shared dropout-subject weights for all examples (inference shortcut).
Requires ``subject_dropout`` to be set.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
*,
n_subjects: int = 200,
bias: bool = True,
init_id: bool = False,
mode: tp.Literal["gather", "for_loop"] = "gather",
subject_dropout: float | None = None,
average_subjects: bool = False,
):
super().__init__()
self.n_subjects = n_subjects
self.subject_dropout = subject_dropout
num_weight_subjects = n_subjects + 1 if subject_dropout else n_subjects
self.weights = nn.Parameter(
torch.empty(num_weight_subjects, in_channels, out_channels)
)
self.bias = (
nn.Parameter(torch.empty(num_weight_subjects, out_channels)) if bias else None
)
if init_id:
if in_channels != out_channels:
raise ValueError(
"in_channels and out_channels must be the same for identity initialization."
)
self.weights.data[:] = torch.eye(in_channels)[None]
if self.bias is not None:
self.bias.data[:] = 0
else:
self.weights.data.normal_()
if self.bias is not None:
self.bias.data.normal_()
self.weights.data *= 1 / in_channels**0.5
if self.bias is not None:
self.bias.data *= 1 / in_channels**0.5
self.average_subjects = average_subjects
self.mode = mode
def forward(
self,
x: torch.Tensor, # (batch_size, in_channels) or (batch_size, in_channels, n_times)
subjects: torch.Tensor, # (batch_size,)
) -> (
torch.Tensor
): # (batch_size, out_channels) or (batch_size, out_channels, n_times)
if x.ndim not in [2, 3]:
raise ValueError(f"Expected shape (B, C, T) or (B, C), got {x.shape}")
if x.ndim == 2:
has_time_dimension = False
x = x.unsqueeze(2)
else:
has_time_dimension = True
B, C, T = x.shape
N, C, D = self.weights.shape
if self.average_subjects:
if not self.subject_dropout:
raise ValueError("subject_dropout must be set to average subjects.")
weights = self.weights[self.n_subjects]
out = torch.einsum("bct,cd->bdt", x, weights)
if self.bias is not None:
out += self.bias[self.n_subjects].view(1, D, 1)
return out if has_time_dimension else out.squeeze(2)
else:
if self.training and self.subject_dropout:
subject_dropout_mask = (
torch.rand(subjects.shape, device=subjects.device)
< self.subject_dropout
)
subjects = subjects.clone()
subjects[subject_dropout_mask] = self.n_subjects
if subjects.max() >= N:
raise ValueError(
f"Subject index {subjects.max()} out of range for {N} weight slots"
)
if self.mode == "gather":
weights = self.weights.index_select(0, subjects.flatten())
out = torch.einsum("bct,bcd->bdt", x, weights)
if self.bias is not None:
out += self.bias.index_select(0, subjects.flatten()).view(B, D, 1)
else:
out = torch.empty((B, D, T), device=x.device)
for subject in subjects.unique():
mask = subjects.reshape(-1) == subject
out[mask] = torch.einsum(
"bct,cd->bdt", x[mask], self.weights[subject]
)
if self.bias is not None:
out[mask] += self.bias[subject].view(1, D, 1)
return out if has_time_dimension else out.squeeze(2)
def __repr__(self):
S, C, D = self.weights.shape
return f"SubjectLayersModel({C}, {D}, {S})"
[docs]
class FourierEmb(BaseModelConfig):
"""Configuration for Fourier positional embedding.
Parameters
----------
n_freqs :
Number of frequencies (harmonics) used to encode **one** dimension.
total_dim :
If provided instead of `n_freqs`, this will be used to compute the number of
frequencies following this relationship:
n_freqs = (total_dim / 2) ** (1 / n_dims)
If the resulting `n_freqs` is not an integer an exception will be raised.
n_dims :
Number of dimensions to embed. This should be 2 for 2D positions (e.g. MNE layouts) or 3
for 3D positions (e.g. MNE montages).
margin :
How much to extend the range of the embedding to avoid edge effects.
"""
n_freqs: int | None = 12
total_dim: int | None = None
n_dims: int = 2
margin: float = 0.2
def build(self) -> "FourierEmbModel":
if self.total_dim is not None and self.n_freqs is None:
n_freqs = (self.total_dim / 2) ** (1 / self.n_dims)
if abs(n_freqs - round(n_freqs)) > 1e-6: # Check if n_freqs is integer
raise ValueError("(total_dim / 2) ** (1 / n_dims) must be an integer.")
n_freqs = round(n_freqs)
elif self.n_freqs is not None and self.total_dim is None:
n_freqs = self.n_freqs
else:
raise ValueError("Exactly one of n_freqs and total_dim must be provided.")
return FourierEmbModel(
n_freqs=n_freqs,
n_dims=self.n_dims,
margin=self.margin,
)
class FourierEmbModel(nn.Module):
"""Fourier positional embedding.
Unlike traditional embedding this is not using exponential periods for cosines and sinuses, but
typical `2 pi k` which can represent any function over [0, 1]. As this function would be
necessarily periodic, we take a bit of margin and do over e.g. [-0.2, 1.2].
"""
def __init__(
self,
n_freqs: int,
n_dims: int,
margin: float,
):
super().__init__()
self.n_freqs = n_freqs
self.n_dims = n_dims
self.margin = margin
# Precompute sin/cos arguments
freqs = torch.arange(n_freqs)
width = 1 + 2 * self.margin
pos = 2 * math.pi * freqs / width
self.register_buffer("pos", pos)
@property
def total_dim(self) -> int:
"""Total dimension of the embedding."""
return (self.n_freqs**self.n_dims) * 2
@staticmethod
def _outer_sum(x: torch.Tensor) -> torch.Tensor:
"""Outer sum between the last dimensions of `x`.
x.shape[-1] is expected to match the dimensions of the grid, between which the outer sum
is computed. For example, if x.shape[-1] == 2, the outer sum will be computed between the
last two dimensions of `x`, i.e. x[..., 0] and x[..., 1].
"""
inds = deque([slice(None)] + [None] * (x.shape[-1] - 1))
out = x[..., 0][(...,) + tuple(inds)]
for i in range(1, x.shape[-1]):
inds.rotate()
out = out + x[..., i][(...,) + tuple(inds)]
return out
def forward(self, positions: torch.Tensor) -> torch.Tensor:
"""Compute Fourier embedding for the given positions.
Parameters
----------
positions : Tensor
Coordinates of shape ``(..., D)`` where ``D`` equals *n_dims*.
"""
*O, D = positions.shape
if D != self.n_dims:
raise ValueError(f"Expected {self.n_dims} positions, got {D}.")
positions = positions + self.margin
locs = torch.einsum("bcd,f->bcfd", positions, self.pos)
loc_grid = self._outer_sum(locs).view(*O, -1)
emb = torch.cat(
[
torch.cos(loc_grid),
torch.sin(loc_grid),
],
dim=-1,
)
return emb
[docs]
class ChannelMerger(BaseModelConfig):
"""Configuration for the ChannelMerger module.
Parameters
----------
embed_ref :
Also embed the reference position, e.g. to enable handling bipolar channels. This requires
passing both `positions` and `ref_positions` to `forward()`.
dropout_around_channel :
If True, randomly sample a channel to apply dropout around. If False, randomly sample a
point in [0, 1] ^ D, where D is the number of dimensions (2 or 3), around which to apply
dropout.
unmerge :
If True, unmerge (rather than merge) channels. This is useful to compute the inverse
operation of a default `ChannelMerger`. In this case, the input to `forward()` should be of
shape (B, n_virtual_channels, T).
invalid_value :
If all position dimensions for a channel are equal to `invalid_value`, the channel will be
masked out. This is useful when examples within a batch contain different channels and
therefore some channels need to be ignored for some of the examples.
NOTE: `ns.extractors.ChannelPositions` defines this value as well.
"""
n_virtual_channels: int = 270
fourier_emb_config: FourierEmb = FourierEmb(
n_freqs=None,
total_dim=288,
n_dims=2,
)
dropout: float = 0
dropout_around_channel: bool = False
usage_penalty: float = 0.0
n_subjects: int = 200
per_subject: bool = False
embed_ref: bool = False
unmerge: bool = False
invalid_value: float = INVALID_POS_VALUE
def build(self) -> "ChannelMergerModel":
return ChannelMergerModel(self)
class ChannelMergerModel(nn.Module):
"""``nn.Module`` implementation of :class:`ChannelMerger`.
Merges (or unmerges) channels via Fourier-embedding-based spatial attention.
"""
def __init__(self, config: ChannelMerger = ChannelMerger()):
super().__init__()
self.embedding = config.fourier_emb_config.build()
self.n_dims = self.embedding.n_dims
pos_dim = self.embedding.total_dim
assert isinstance(pos_dim, int) # for mypy
self.per_subject = config.per_subject
self.embed_ref = config.embed_ref
n_params_pos_dim = pos_dim * 2 if self.embed_ref else pos_dim
if self.per_subject:
self.heads = nn.Parameter(
torch.randn(
config.n_subjects, config.n_virtual_channels, n_params_pos_dim
)
)
else:
self.heads = nn.Parameter(
torch.randn(config.n_virtual_channels, n_params_pos_dim)
)
self.invalid_value = config.invalid_value
self.heads.data /= pos_dim**0.5 # XXX Double check
self.dropout = config.dropout
self.dropout_around_channel = config.dropout_around_channel
self.usage_penalty = config.usage_penalty
self._penalty = torch.tensor(0.0)
self.unmerge = config.unmerge
@property
def training_penalty(self):
return self._penalty.to(next(self.parameters()).device)
def _get_weights(
self,
subject_ids: torch.Tensor,
positions: torch.Tensor,
device: torch.device,
) -> torch.Tensor:
B, C, _ = positions.shape
if self.embed_ref:
if positions.shape[2] == self.n_dims:
ref_pad = torch.full(
(B, C, self.n_dims),
self.invalid_value,
dtype=positions.dtype,
device=positions.device,
)
positions = torch.cat([positions, ref_pad], dim=2)
if positions.shape[2] != self.n_dims * 2: # type: ignore
got = positions.shape[2]
raise ValueError(f"embed_ref needs {self.n_dims * 2} dims, got {got}")
embedding = torch.cat(
[
self.embedding(positions[..., : self.n_dims]), # type: ignore
self.embedding(positions[..., self.n_dims :]), # type: ignore
],
dim=2,
)
else:
if positions.shape[2] != self.n_dims:
raise ValueError(
f"Expected {self.n_dims} spatial dimensions, got {positions.shape[2]}"
)
embedding = self.embedding(positions)
score_offset = torch.zeros(B, C, device=device)
invalid_mask = (positions == self.invalid_value).all(dim=-1)
score_offset = score_offset.masked_fill(invalid_mask, float("-inf"))
if self.training and self.dropout:
if self.unmerge:
raise NotImplementedError(
"Figure out how to apply dropout if unmerge=True"
)
if self.dropout_around_channel:
all_valid_positions = positions[~invalid_mask]
ind = int(torch.randint(0, all_valid_positions.shape[0], (1,))[0])
center_to_ban = all_valid_positions[ind, : self.n_dims].to(device)
else:
center_to_ban = torch.rand(self.n_dims, device=device)
radius_to_ban = self.dropout
banned = (positions[:, :, : self.n_dims] - center_to_ban).norm(
dim=-1
) <= radius_to_ban
score_offset = score_offset.masked_fill(banned, float("-inf"))
if self.per_subject:
_, cout, pos_dim = self.heads.shape
heads = self.heads.gather(
0, subject_ids.view(-1, 1, 1).expand(-1, cout, pos_dim)
)
else:
heads = self.heads[None].expand(B, -1, -1)
scores = torch.einsum("bcd,bod->boc", embedding, heads)
scores += score_offset[:, None]
if self.unmerge:
scores = scores.transpose(1, 2)
return torch.softmax(scores, dim=2).nan_to_num() # Replace nans by 0
def forward(
self,
meg: torch.Tensor,
subject_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
"""Apply spatial attention on input.
Parameters
----------
positions :
Normalized (x, y) coordinates for each channel in `meg`, of shape (B, C, 2).
If shape is (B, C, 4), the additional two coordinates per channel indicate the
position of the reference electrode. See `ns.extractors.ChannelPositions`.
"""
weights = self._get_weights(subject_ids, positions, meg.device)
out = weights @ meg
if self.training and self.usage_penalty > 0.0:
usage = weights.mean(dim=(0, 1)).sum()
self._penalty = torch.tensor(self.usage_penalty * usage, device=meg.device)
return out
[docs]
class LayerScale(nn.Module):
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
This rescales diagonaly residual outputs close to 0 initially, then learnt.
"""
def __init__(self, channels: int, init: float = 0.1, boost: float = 5.0):
super().__init__()
self.scale = nn.Parameter(torch.zeros(channels))
self.scale.data[:] = init / boost
self.boost = boost
[docs]
def forward(self, x):
return (self.boost * self.scale[:, None]) * x
class UnitNorm(nn.Module):
"""Normalize last dimension of tensor to have unit Frobenius norm.
Useful for parametrizing different normalization alternatives in `Mlp` below.
NOTE: `hidden_dim` argument included for consistency with other normalization layers (e.g.
BatchNorm).
"""
def __init__(self, hidden_dim: int = 0) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x / x.norm(p="fro", dim=-1, keepdim=True)
class Mean(nn.Module):
def __init__(self, dim: int, keepdim: bool = False):
super().__init__()
self.dim = dim
self.keepdim = keepdim
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.mean(dim=self.dim, keepdim=self.keepdim)
[docs]
class NormDenormScaler(nn.Module):
"""Norm-denorm scaler inspired by [1]_.
At inference time, this module applies z-score normalization of its input, followed by
de-normalization based on the statistics of the data seen at instantiation.
Parameters
----------
x :
Data on which to fit the denormalizer, of shape (n_examples, n_features).
affine :
If True, de-normalize with the statistics of `x`.
References
----------
.. [1] Ozcelik, Furkan, and Rufin VanRullen. "Natural scene reconstruction from fMRI signals
using generative latent diffusion." Scientific Reports 13.1 (2023): 15666.
"""
def __init__(self, x: torch.Tensor, affine: bool = True):
super().__init__()
if x.ndim != 2:
raise ValueError(f"Tensor must be 2D (flattened), got ndim={x.ndim}")
self.scaler = nn.BatchNorm1d(
x.shape[1], affine=affine, track_running_stats=False, eps=1e-15
).eval()
if affine:
# Disable gradient as this is not currently intended to be finetuned
self.scaler.weight.requires_grad = False
self.scaler.bias.requires_grad = False
self.scaler.weight.data = x.std(dim=0, correction=0)
self.scaler.bias.data = x.mean(dim=0)
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.scaler(x)
[docs]
class Mlp(BaseModelConfig):
"""Multilayer perceptron, e.g. for use as projection head.
Notes
-----
Input size can be specified in the config or at build time.
Output size can be specified at build time through the `output_size` parameter: it will be
appended to `hidden_sizes` as the final layer.
When `hidden_sizes` is empty or None and `output_size` is given, `build()` returns a
single ``nn.Linear`` layer. When both are absent it returns ``nn.Identity``.
"""
input_size: int | None = None
hidden_sizes: list[int] | None = None
norm_layer: tp.Literal["layer", "batch", "instance", "unit", None] = None
activation_layer: tp.Literal["relu", "gelu", "elu", "prelu", None] = "relu"
bias: bool = True
dropout: float = 0.0
@staticmethod
def _get_norm_layer(kind: str | None) -> tp.Type[nn.Module] | None:
return {
"batch": nn.BatchNorm1d,
"layer": nn.LayerNorm,
"instance": nn.InstanceNorm1d,
"unit": UnitNorm,
None: None,
}[kind]
@staticmethod
def _get_activation_layer(kind: str | None) -> tp.Type[nn.Module]:
return {
"gelu": nn.GELU,
"relu": nn.ReLU,
"elu": nn.ELU,
"prelu": nn.PReLU,
None: nn.Identity,
}[kind]
def build(
self, input_size: int | None = None, output_size: int | None = None
) -> nn.Sequential | nn.Linear | nn.Identity:
input_size = self.input_size if input_size is None else input_size
if input_size is None:
raise ValueError("input_size cannot be None.")
if not self.hidden_sizes:
if output_size is None:
return nn.Identity()
return nn.Linear(input_size, output_size)
hidden_sizes = self.hidden_sizes.copy()
if output_size is not None:
hidden_sizes.append(output_size)
return MLP(
in_channels=input_size,
hidden_channels=hidden_sizes,
norm_layer=self._get_norm_layer(self.norm_layer),
activation_layer=self._get_activation_layer(self.activation_layer),
bias=self.bias,
dropout=self.dropout,
)
[docs]
class TemporalDownsampling(BaseModelConfig):
"""Temporal downsampling via a 2-D convolution over the time axis.
Parameters
----------
kernel_size : int
Kernel height (time-axis) of the downsampling convolution.
stride : int
Stride along the time axis.
layer_norm : bool
Apply ``LayerNorm`` after the convolution.
layer_norm_affine : bool
Use learnable affine parameters in the ``LayerNorm``.
gelu : bool
Apply GELU activation after normalization.
"""
kernel_size: int = 45
stride: int = 45
layer_norm: bool = True
layer_norm_affine: bool = True
gelu: bool = True
def build(self, dim: int) -> nn.Module:
return TemporalDownsamplingModel(dim=dim, config=self)
class TemporalDownsamplingModel(nn.Module):
"""``nn.Module`` implementation of :class:`TemporalDownsampling`."""
def __init__(
self,
dim: int,
config: TemporalDownsampling | None = None,
) -> None:
config = config if config is not None else TemporalDownsampling()
super().__init__()
self.agg = nn.Conv2d(
1, 1, kernel_size=(config.kernel_size, 1), stride=(config.stride, 1)
)
self.layer_norm = (
nn.LayerNorm(dim, eps=1e-8, elementwise_affine=config.layer_norm_affine)
if config.layer_norm
else None
)
self.gelu = nn.GELU() if config.gelu else None
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x.shape expected to be (B, 1, T, F)
if x.ndim != 4:
raise ValueError(f"Input must be of shape (B, 1, T, F), got {x.shape}")
x = self.agg(x)
if self.layer_norm:
x = self.layer_norm(x)
if self.gelu:
x = self.gelu(x)
return x