# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""SimpleConv, taken and slightly modified from brainmagick."""
import logging
import typing as tp
from functools import partial
import torch
from torch import nn
from .base import BaseModelConfig
from .common import (
BahdanauAttention,
ChannelMerger,
FourierEmb,
LayerScale,
Mlp,
SubjectLayers,
)
from .transformer import TransformerEncoder
logger = logging.getLogger(__name__)
class ConvSequence(nn.Module):
def __init__(
self,
channels: tp.Sequence[int],
kernel: int = 4,
dilation_growth: int = 1,
dilation_period: int | None = None,
stride: int = 2,
dropout: float = 0.0,
leakiness: float = 0.0,
groups: int = 1,
decode: bool = False,
batch_norm: bool = False,
dropout_input: float = 0.0,
skip: bool = False,
scale: float | None = None,
rewrite: bool = False,
activation_on_last: bool = True,
post_skip: bool = False,
glu: int = 0,
glu_context: int = 0,
glu_glu: bool = True,
activation: tp.Any = None,
) -> None:
super().__init__()
dilation = 1
channels = tuple(channels)
self.skip = skip
self.sequence = nn.ModuleList()
self.glus = nn.ModuleList()
if activation is None:
activation = partial(nn.LeakyReLU, leakiness)
Conv = nn.Conv1d if not decode else nn.ConvTranspose1d
# build layers
for k, (chin, chout) in enumerate(zip(channels[:-1], channels[1:])):
layers: tp.List[nn.Module] = []
is_last = k == len(channels) - 2
# Set dropout for the input of the conv sequence if defined
if k == 0 and dropout_input:
if not (0.0 < dropout_input < 1.0):
raise ValueError(f"{dropout_input=} must be in (0,1)")
layers.append(nn.Dropout(dropout_input))
# conv layer
if dilation_growth > 1:
if kernel % 2 == 0:
raise ValueError(f"Odd kernel required with dilation, got {kernel}")
if dilation_period and (k % dilation_period) == 0:
dilation = 1
pad = kernel // 2 * dilation
layers.append(
Conv(
chin,
chout,
kernel,
stride,
pad,
dilation=dilation,
groups=groups if k > 0 else 1,
)
)
dilation *= dilation_growth
# non-linearity
if activation_on_last or not is_last:
if batch_norm:
layers.append(nn.BatchNorm1d(num_features=chout))
layers.append(activation())
if dropout:
layers.append(nn.Dropout(dropout))
if rewrite:
layers += [nn.Conv1d(chout, chout, 1), nn.LeakyReLU(leakiness)]
# layers += [nn.Conv1d(chout, 2 * chout, 1), nn.GLU(dim=1)]
if chin == chout and skip:
if scale is not None:
layers.append(LayerScale(chout, scale))
if post_skip:
layers.append(Conv(chout, chout, 1, groups=chout, bias=False))
self.sequence.append(nn.Sequential(*layers))
if glu and (k + 1) % glu == 0:
ch = 2 * chout if glu_glu else chout
act = nn.GLU(dim=1) if glu_glu else activation()
self.glus.append(
nn.Sequential(
nn.Conv1d(chout, ch, 1 + 2 * glu_context, padding=glu_context),
act,
)
)
else:
self.glus.append(None) # type: ignore
def forward(self, x: tp.Any) -> tp.Any:
for module_idx, module in enumerate(self.sequence):
old_x = x
x = module(x)
if self.skip and x.shape == old_x.shape:
x = x + old_x
glu = self.glus[module_idx]
if glu is not None:
x = glu(x)
return x
class SimpleConv(BaseModelConfig):
    """1-D convolutional encoder, adapted from brainmagick.

    Parameters
    ----------
    hidden : int
        Number of channels in the first convolutional layer.
    depth : int
        Number of convolutional layers.
    linear_out : bool
        Use a single transposed convolution as the output projection.
    complex_out : bool
        Use a two-layer transposed-convolution output projection with a
        non-linearity in between. Mutually exclusive with *linear_out*.
    kernel_size : int
        Kernel size for every convolutional layer (must be odd).
    growth : float
        Multiplicative channel growth factor per layer.
    dilation_growth : int
        Multiplicative dilation growth factor per layer.
    dilation_period : int or None
        If set, reset dilation to 1 every *dilation_period* layers.
    skip : bool
        Add residual skip connections when input and output shapes match.
    post_skip : bool
        Append a depth-wise convolution after each skip connection.
    scale : float or None
        If set, apply :class:`LayerScale` with this initial value after
        each skip connection.
    rewrite : bool
        Append a 1x1 convolution + LeakyReLU after each layer.
    groups : int
        Number of groups for grouped convolutions (first layer always uses 1).
    glu : int
        If non-zero, insert a GLU gate every *glu* layers.
    glu_context : int
        Context (padding) size for the GLU convolution.
    glu_glu : bool
        If True the gate uses ``nn.GLU``; otherwise the layer activation.
    gelu : bool
        Use GELU activation instead of (Leaky)ReLU.
    dropout : float
        Channel-dropout probability (currently raises ``NotImplementedError``).
    dropout_rescale : bool
        Rescale activations after channel dropout.
    conv_dropout : float
        Dropout probability inside each convolutional block.
    dropout_input : float
        Dropout probability applied to the input of the convolutional stack.
    batch_norm : bool
        Apply batch normalization after each convolution.
    relu_leakiness : float
        Negative slope for ``LeakyReLU`` (0 gives standard ReLU).
    transformer_config : TransformerEncoder or None
        If set, append a Transformer encoder after the convolutional stack.
    subject_layers_config : SubjectLayers or None
        If set, prepend a per-subject linear projection.
    subject_layers_dim : {"input", "hidden"}
        Dimension used for the subject-layer projection.
    merger_config : ChannelMerger or None
        If set, prepend a :class:`ChannelMerger` for multi-montage support.
    initial_linear : int
        If non-zero, prepend a 1x1 convolution projecting to this many channels.
    initial_depth : int
        Number of 1x1 convolution layers in the initial projection.
    initial_nonlin : bool
        Append a non-linearity after the initial 1x1 projection stack.
    backbone_out_channels : int or None
        If set, force the backbone output to this many channels.
    """

    # Channels
    hidden: int = 16
    # Overall structure
    depth: int = 4
    linear_out: bool = False
    complex_out: bool = False
    # Conv layer
    kernel_size: int = 5
    growth: float = 1.0
    dilation_growth: int = 2
    dilation_period: int | None = None
    skip: bool = False
    post_skip: bool = False
    scale: float | None = None
    rewrite: bool = False
    groups: int = 1
    glu: int = 0
    glu_context: int = 0
    glu_glu: bool = True
    gelu: bool = False
    # Dropouts, BN, activations
    dropout: float = 0.0
    dropout_rescale: bool = True
    conv_dropout: float = 0.0
    dropout_input: float = 0.0
    batch_norm: bool = False
    relu_leakiness: float = 0.0
    # Optional transformer
    transformer_config: TransformerEncoder | None = None
    # Subject-specific settings
    subject_layers_config: SubjectLayers | None = None
    subject_layers_dim: tp.Literal["input", "hidden"] = "hidden"
    # Channel attention for multi-montage support
    merger_config: ChannelMerger | None = ChannelMerger(
        n_virtual_channels=270,
        fourier_emb_config=FourierEmb(
            n_freqs=None,
            total_dim=2048,
            n_dims=2,
        ),
        dropout=0.2,
        usage_penalty=0.0,
        per_subject=False,
        embed_ref=False,
    )
    # Architectural details
    initial_linear: int = 0
    initial_depth: int = 1
    initial_nonlin: bool = False
    # If provided, the output of the backbone (i.e. the layer before the
    # output heads) will have this dimensionality instead of the model's
    # number of outputs.
    backbone_out_channels: int | None = None

    def build(self, n_in_channels: int, n_outputs: int) -> "SimpleConvModel":
        """Instantiate a :class:`SimpleConvModel` from this configuration."""
        return SimpleConvModel(n_in_channels, n_outputs, config=self)
class SimpleConvModel(nn.Module):
    """``nn.Module`` implementation of :class:`SimpleConv`.

    Pipeline: optional channel merger -> optional initial 1x1 projection ->
    optional subject layers -> :class:`ConvSequence` backbone -> optional
    output projection -> optional transformer, with the output trimmed back
    to the input length.
    """

    def __init__(
        self,
        # Channels
        in_channels: int,
        out_channels: int,
        config: SimpleConv | None = None,
    ):
        super().__init__()
        if config is None:
            config = SimpleConv()
        self.out_channels = out_channels
        self.backbone_out_channels = config.backbone_out_channels or out_channels

        # Shared activation factory for the whole network.
        activation: nn.Module | tp.Callable
        if config.gelu:
            activation = nn.GELU
        elif config.relu_leakiness:
            activation = partial(nn.LeakyReLU, config.relu_leakiness)
        else:
            activation = nn.ReLU

        if config.kernel_size % 2 != 1:
            raise ValueError(f"{config.kernel_size=}, must be odd for padding")

        self.dropout = None  # placeholder until channel dropout is re-added
        self.initial_linear = None
        if config.dropout > 0.0:
            raise NotImplementedError("To be reimplemented here.")

        # Optional multi-montage channel merger; overrides the input size.
        self.merger = None
        if config.merger_config is not None:
            self.merger = config.merger_config.build()
            in_channels = config.merger_config.n_virtual_channels

        # Optional stack of 1x1 convolutions projecting the input channels.
        if config.initial_linear:
            projection: list[nn.Module | tp.Callable] = [
                nn.Conv1d(in_channels, config.initial_linear, 1)
            ]
            for _ in range(config.initial_depth - 1):
                projection.append(activation())
                projection.append(
                    nn.Conv1d(config.initial_linear, config.initial_linear, 1)
                )
            if config.initial_nonlin:
                projection.append(activation())
            self.initial_linear = nn.Sequential(*projection)  # type: ignore[arg-type]
            in_channels = config.initial_linear

        # Optional per-subject linear projection.
        self.subject_layers = None
        if config.subject_layers_config:
            dim = {"hidden": config.hidden, "input": in_channels}[
                config.subject_layers_dim
            ]
            self.subject_layers = config.subject_layers_config.build(in_channels, dim)
            in_channels = dim

        # Backbone channel sizes: input, then hidden * growth**k per layer.
        sizes = [in_channels]
        sizes.extend(
            int(round(config.hidden * config.growth**k)) for k in range(config.depth)
        )

        conv_params: dict[str, tp.Any] = dict(
            kernel=config.kernel_size,
            stride=1,
            leakiness=config.relu_leakiness,
            dropout=config.conv_dropout,
            dropout_input=config.dropout_input,
            batch_norm=config.batch_norm,
            dilation_growth=config.dilation_growth,
            groups=config.groups,
            dilation_period=config.dilation_period,
            skip=config.skip,
            post_skip=config.post_skip,
            scale=config.scale,
            rewrite=config.rewrite,
            glu=config.glu,
            glu_context=config.glu_context,
            glu_glu=config.glu_glu,
            activation=activation,
        )

        # Output projection: either a dedicated head (linear/complex), or the
        # last encoder layer itself, without a final activation.
        final_channels = sizes[-1]
        self.final: nn.Module | nn.Sequential | None = None
        kernel = 1
        stride = 1
        pad = 0
        if config.linear_out:
            if config.complex_out:
                raise ValueError("linear_out and complex_out are mutually exclusive")
            self.final = nn.ConvTranspose1d(
                final_channels, self.backbone_out_channels, kernel, stride, pad
            )
        elif config.complex_out:
            self.final = nn.Sequential(
                nn.Conv1d(final_channels, 2 * final_channels, 1),
                activation(),
                nn.ConvTranspose1d(
                    2 * final_channels, self.backbone_out_channels, kernel, stride, pad
                ),
            )
        else:
            conv_params["activation_on_last"] = False
            sizes[-1] = self.backbone_out_channels

        self.encoder = ConvSequence(sizes, **conv_params)

        # Optional transformer on top of the convolutional backbone.
        self.transformer = None
        if config.transformer_config:
            self.transformer = config.transformer_config.build(
                dim=self.backbone_out_channels
            )

    def forward(
        self,
        x: torch.Tensor,
        subject_ids: torch.Tensor | None = None,
        channel_positions: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the convolutional encoder.

        Parameters
        ----------
        x : Tensor
            Input of shape ``(B, C, T)``.
        subject_ids : Tensor or None
            Per-example subject indices, shape ``(B,)``.
        channel_positions : Tensor or None
            Normalised electrode coordinates, shape ``(B, C, D)``.
        """
        length = x.shape[-1]
        if self.merger is not None:
            x = self.merger(x, subject_ids, channel_positions)
        if self.initial_linear is not None:
            x = self.initial_linear(x)
        if self.subject_layers is not None:
            x = self.subject_layers(x, subject_ids)
        x = self.encoder(x)
        if self.final is not None:
            x = self.final(x)
        # The stack may only lengthen the sequence; trim back to input length.
        if x.shape[-1] < length:
            raise ValueError(f"Expected output time dim >= {length}, got {x.shape[-1]}")
        x = x[:, :, :length]
        if self.transformer:
            x = self.transformer(x.transpose(1, 2)).transpose(1, 2)
        return x
class SimpleConvTimeAgg(SimpleConv):
    """SimpleConv with temporal aggregation layer and optional output heads.

    Parameters
    ----------
    time_agg_out :
        - "gap" : Global average pooling
        - "linear" : Linear layer with one output
        - "att" : Bahdanau attention layer
    n_time_groups :
        Number of groups within which to apply temporal aggregation, e.g. 4 means the time
        dimension will be split into 4 groups and each group will be aggregated (and optionally
        projected) separately.
    output_head_config :
        Optional MLP output head(s) applied after temporal aggregation. A
        single ``Mlp`` config builds one head; a dict of configs builds one
        named head per entry (e.g. for separate losses). A trailing ``-1`` in
        an ``Mlp``'s ``hidden_sizes`` is a sentinel that the model replaces
        with its number of output channels at build time.
    """

    # Temporal aggregation
    time_agg_out: tp.Literal["gap", "linear", "att"] = "gap"
    n_time_groups: int | None = None
    # Output head(s)
    output_head_config: Mlp | dict[str, Mlp] | None = None

    def build(self, n_in_channels: int, n_outputs: int) -> "SimpleConvTimeAggModel":
        """Instantiate a :class:`SimpleConvTimeAggModel` from this configuration."""
        return SimpleConvTimeAggModel(n_in_channels, n_outputs, config=self)
class SimpleConvTimeAggModel(SimpleConvModel):
    """``nn.Module`` implementation of :class:`SimpleConvTimeAgg`.

    Adds a temporal aggregation layer (global average pooling, a linear
    projection over time, or Bahdanau attention) on top of
    :class:`SimpleConvModel` — optionally applied per time group — followed
    by optional MLP output head(s).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        config: SimpleConvTimeAgg | None = None,
    ):
        if config is None:
            config = SimpleConvTimeAgg()
        super().__init__(
            in_channels=in_channels, out_channels=out_channels, config=config
        )
        self.n_time_groups = config.n_time_groups

        # Temporal aggregation layer.
        self.time_agg_out: nn.Module | None
        mode = config.time_agg_out
        if mode == "gap":
            self.time_agg_out = nn.AdaptiveAvgPool1d(1)
        elif mode == "linear":
            self.time_agg_out = nn.LazyLinear(1)
        elif mode == "att":
            self.time_agg_out = BahdanauAttention(input_size=None, hidden_size=256)
        else:
            self.time_agg_out = None

        # Separate output head(s), e.g. for separate CLIP and MSE losses.
        self.output_head: nn.Module | None
        head_config = config.output_head_config
        if head_config is None:
            self.output_head = None
        else:
            if self.time_agg_out is None:
                raise NotImplementedError("Output heads require temporal aggregation.")
            if isinstance(head_config, Mlp):
                resolved, output_size = self._resolve_head_sentinel(head_config)
                self.output_head = resolved.build(
                    input_size=self.backbone_out_channels,
                    output_size=output_size,
                )
            elif isinstance(head_config, dict):
                self.output_head = nn.ModuleDict()
                for name, sub_config in head_config.items():
                    resolved, output_size = self._resolve_head_sentinel(sub_config)
                    self.output_head[name] = resolved.build(
                        input_size=self.backbone_out_channels,
                        output_size=output_size,
                    )

    def _resolve_head_sentinel(self, cfg: Mlp) -> tuple[Mlp, int | None]:
        """Strip the ``-1`` output-size sentinel from an ``Mlp`` config.

        Returns a (possibly updated) config and the resolved ``output_size``
        (the model's ``out_channels`` when the sentinel was present, else
        ``None``).
        """
        hidden = cfg.hidden_sizes
        if hidden and hidden[-1] == -1:
            trimmed = cfg.model_copy(update={"hidden_sizes": hidden[:-1]})
            return trimmed, self.out_channels
        return cfg, None

    def forward(  # type: ignore
        self, x, subject_ids=None, channel_positions=None
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        batch_size = x.shape[0]
        x = super().forward(
            x, subject_ids=subject_ids, channel_positions=channel_positions
        )
        n_groups = self.n_time_groups
        if n_groups is not None:
            # Split the time axis into groups (with padding that mirrors
            # ns.extractors.meta.TimeAggregatedExtractor), then fold the group
            # index into the batch dimension.
            chunks = torch.tensor_split(
                x.permute(2, 0, 1), n_groups, dim=0
            )  # (B, F, T) -> N x (~(T // N), B, F)
            x = nn.utils.rnn.pad_sequence(
                chunks, batch_first=True
            )  # -> (N, ~(T // N), B, F)
            x = x.permute(0, 2, 3, 1)  # -> (N, B, F, ~(T // N))
            x = x.flatten(end_dim=1)  # -> (NxB, F, ~(T // N))
        if self.time_agg_out is not None:
            x = self.time_agg_out(x)
            if x.ndim == 3:
                x = x.squeeze(2)  # drop the aggregated (singleton) time axis
        # Apply output head(s) (e.g. for separate CLIP and MSE losses).
        if isinstance(self.output_head, nn.ModuleDict):
            x = {name: head(x) for name, head in self.output_head.items()}
        elif self.output_head is not None:
            x = self.output_head(x)
        # Restore the group axis as the trailing (time) dimension.
        if n_groups is not None:
            if isinstance(x, dict):
                raise NotImplementedError
            x = x.reshape(n_groups, batch_size, -1)  # (NxB, F) -> (N, B, F)
            x = x.permute(1, 2, 0)  # -> (B, F, N)
        return x