Source code for neuraltrain.models.simplerconv

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from math import prod

import torch
from torch import nn

from .base import BaseModelConfig
from .common import ChannelMerger, SubjectLayers

logger = logging.getLogger(__name__)


[docs] class SimplerConv(BaseModelConfig): """Convolutional encoder inspired by BENDR / wav2vec 2.0. Parameters ---------- subject_layers_config : SubjectLayers or None If set, prepend a per-subject linear projection. merger_config : ChannelMerger or None If set, prepend a :class:`ChannelMerger` for multi-montage support. n_hidden_channels : int Number of hidden channels in each convolutional layer. kernel_sizes : tuple of int Kernel size for each convolutional layer. strides : tuple of int Stride for each convolutional layer. Must have the same length as *kernel_sizes*. output_nonlin : bool Apply GELU after the last convolutional layer. dropout : float Channel-dropout (``Dropout1d``) probability after each convolution. """ # Subject specific layer subject_layers_config: SubjectLayers | None = None # Channel merger to allow learning on different montages merger_config: ChannelMerger | None = None # Convolutional encoder # NOTE: Default values are inspired from BENDR/wav2vec2.0 n_hidden_channels: int = 512 kernel_sizes: tuple[int, ...] = (3, 2, 2, 2, 2) strides: tuple[int, ...] = (3, 2, 2, 2, 2) output_nonlin: bool = True dropout: float = 0.0 def build(self, n_in_channels: int, n_outputs: int | None = None) -> nn.Module: return SimplerConvModel(n_in_channels, n_outputs, config=self)
[docs] class SimplerConvModel(nn.Module): """Convolutional encoder inspired from BENDR/wav2vec2.0, with channel attention.""" def __init__( self, in_channels: int, out_channels: int | None, config: SimplerConv | None = None, ): super().__init__() config = config if config is not None else SimplerConv() if len(config.kernel_sizes) != len(config.strides): nk = len(config.kernel_sizes) ns = len(config.strides) raise ValueError(f"kernel_sizes/strides length mismatch: {nk} vs {ns}") self.kernel_sizes = config.kernel_sizes self.strides = config.strides self.out_channels = out_channels self.merger = None if config.merger_config is not None: self.merger = config.merger_config.build() in_channels = config.merger_config.n_virtual_channels self.subject_layers = None if config.subject_layers_config: self.subject_layers = config.subject_layers_config.build( in_channels, config.n_hidden_channels ) in_channels = config.n_hidden_channels # Convolutional encoder self.encoder = nn.ModuleList() n_layers = len(self.kernel_sizes) last_out_channels = ( config.n_hidden_channels if self.out_channels is None else self.out_channels ) for i, (kernel_size, stride) in enumerate(zip(self.kernel_sizes, self.strides)): out_channels = ( config.n_hidden_channels if i < n_layers - 1 else last_out_channels ) if i < n_layers - 1: out_channels = config.n_hidden_channels nonlin = nn.GELU() else: out_channels = last_out_channels nonlin = nn.GELU() if config.output_nonlin else nn.Identity() # type: ignore self.encoder.append( nn.Sequential( nn.Conv1d( in_channels if i == 0 else config.n_hidden_channels, out_channels, kernel_size=kernel_size, stride=stride, ), nn.Dropout1d(config.dropout), # Channel dropout nn.GroupNorm( out_channels // 2, out_channels ), # Like in BENDR, see https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py#L386 nonlin, ) ) @property def downsampling_factor(self) -> int: return prod(self.strides) @property def receptive_field(self) -> int: """Compute the receptive field of the encoder. See https://distill.pub/2019/computing-receptive-fields/, Eq. 2 """ cumprod_strides = [1] + torch.cumprod( torch.tensor(self.strides[:-1]), dim=0 ).tolist() return int( sum((k - 1) * s for k, s in zip(self.kernel_sizes, cumprod_strides)) + 1 )
[docs] def forward( self, x: torch.Tensor, subject_ids: torch.Tensor | None = None, channel_positions: torch.Tensor | None = None, ) -> torch.Tensor: """Run the convolutional encoder. Parameters ---------- x : Tensor Input of shape ``(B, C, T)``. subject_ids : Tensor or None Per-example subject indices, shape ``(B,)``. channel_positions : Tensor or None Normalised electrode coordinates, shape ``(B, C, D)``. """ if self.merger is not None: x = self.merger(x, subject_ids, channel_positions) if self.subject_layers is not None: x = self.subject_layers(x, subject_ids) for block in self.encoder: x = block(x) return x