# Source code for neuraltrain.models.conv_transformer
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
from torch import nn
from .base import BaseModelConfig
from .common import TemporalDownsampling
from .conformer import Conformer
from .simpleconv import SimpleConv
from .simplerconv import SimplerConv
from .transformer import TransformerEncoder
logger = logging.getLogger(__name__)
class ConvTransformer(BaseModelConfig):
    """Convolutional encoder followed by optional temporal aggregation and a transformer.

    Parameters
    ----------
    dim :
        Internal token dimension.
    encoder_config :
        Configuration for the convolutional encoder.
    temporal_downsampling_config :
        Configuration for the optional temporal downsampling module.
    conv_pos_emb_kernel_size :
        If provided, use convolutional positional embedding with this kernel size.
    neuro_device_types :
        List of expected neuro device types that can be used to embed the device type in the
        transformer.
    add_cls_token :
        If True, add a [CLS] token to the input of the transformer.
    pre_transformer_layer_norm :
        If True, apply layer normalization before the transformer.
    transformer_config :
        Configuration for the transformer encoder.
    output_avg_pool :
        If True, average the tokens outputted by the transformer.
    output_layer_dim :
        Set to 0 for no output layer, or None to use the same dimension as the transformer. Of
        note, both Bendr and Wav2vec2.0 use an output linear projection though it's not mentioned
        in their respective papers.
    """

    dim: int = 512
    encoder_config: SimplerConv | SimpleConv
    temporal_downsampling_config: TemporalDownsampling | None = None
    conv_pos_emb_kernel_size: int | None = None
    neuro_device_types: list[str] | None = None
    add_cls_token: bool = False
    pre_transformer_layer_norm: bool = False
    transformer_config: TransformerEncoder | Conformer | None = None
    output_avg_pool: bool = False
    output_layer_dim: int | None = 0

    def build(
        self, n_in_channels: int, n_outputs: int | None = None
    ) -> "ConvTransformerModel":
        """Build ConvTransformer model.

        Parameters
        ----------
        n_in_channels :
            Number of input channels.
        n_outputs :
            Number of output dimensions. If None, use the `output_layer_dim` parameter from the
            config.
        """
        # BUGFIX: the previous `n_outputs or self.output_layer_dim` treated an
        # explicit `n_outputs=0` (documented to mean "no output layer") as falsy
        # and silently fell back to the config value. Only `None` should trigger
        # the fallback to `output_layer_dim`.
        out_dim = self.output_layer_dim if n_outputs is None else n_outputs
        return ConvTransformerModel(
            n_in_channels,
            out_dim,
            config=self,
        )
class ConvTransformerModel(nn.Module):
    """``nn.Module`` implementation of :class:`ConvTransformer`."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int | None,  # Width of the optional output linear layer
        config: ConvTransformer,
    ):
        super().__init__()
        # Convolutional encoder: maps raw channels to `dim`-wide feature tokens.
        self.dim = config.dim
        self.encoder = config.encoder_config.build(in_channels, self.dim)

        # Optional temporal aggregation of the encoder output.
        ds_cfg = config.temporal_downsampling_config
        self.temporal_downsampling = (
            ds_cfg.build(self.dim) if ds_cfg is not None else None
        )

        # Optional LayerNorm applied right before the transformer.
        self.pre_transformer_layer_norm = (
            nn.LayerNorm(self.dim) if config.pre_transformer_layer_norm else None
        )

        # Optional transformer over the token sequence.
        self.transformer = (
            config.transformer_config.build(dim=self.dim)
            if config.transformer_config is not None
            else None
        )

        # Optional learnable [CLS] token, prepended to the token sequence.
        if config.add_cls_token:
            self.cls_token = nn.Parameter(torch.empty(1, 1, self.dim))
            nn.init.normal_(self.cls_token, mean=0.0, std=1.0)
        else:
            self.cls_token = None

        # Optional convolutional relative positional embedding.
        # See https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/models/wav2vec2/position_encoder.py
        # See https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py#L522
        self.rel_pos_emb = None
        if config.conv_pos_emb_kernel_size is not None:
            k_size = config.conv_pos_emb_kernel_size
            pos_conv = nn.Conv1d(
                self.dim,
                self.dim,
                k_size,
                padding="same",
                groups=16,  # XXX Parametrize
            )
            nn.init.normal_(
                pos_conv.weight, mean=0.0, std=(4.0 / (k_size * self.dim)) ** 0.5
            )
            nn.init.constant_(pos_conv.bias, 0.0)  # type: ignore
            # XXX weight_norm will be deprecated in favour of parametrizations,
            # but the parametrized version is not yet compatible here.
            # pos_conv = nn.utils.parametrizations.weight_norm(pos_conv, dim=2)
            pos_conv = nn.utils.weight_norm(pos_conv, dim=2)
            self.rel_pos_emb = nn.Sequential(pos_conv, nn.GELU())

        # Optional embedding of the neuro device type (e.g. "Eeg", "Meg").
        # Types are sorted so the index -> embedding mapping is stable.
        self.neuro_device_types = None
        if config.neuro_device_types is not None:
            self.neuro_device_types = sorted(config.neuro_device_types)
            self.neuro_device_emb = nn.Embedding(
                len(self.neuro_device_types), self.dim
            )

        # Output head: optional average pooling followed by an optional linear
        # layer (out_channels=0 disables it, None falls back to `dim`).
        self.output_avg_pool = config.output_avg_pool
        if out_channels == 0:
            self.output_layer = None
        else:
            self.output_layer = nn.Linear(self.dim, out_channels or self.dim)

    def _encoder_and_downsampling_forward(
        self,
        x: torch.Tensor,
        subject_ids: torch.Tensor | None = None,
        channel_positions: torch.Tensor | None = None,
    ):
        """Run the conv encoder (plus optional downsampling); returns (B, T, F)."""
        feats = self.encoder.forward(
            x, subject_ids=subject_ids, channel_positions=channel_positions
        )
        feats = feats.transpose(1, 2)  # (B, F, T) -> (B, T, F)
        if self.temporal_downsampling is None:
            return feats
        return self.temporal_downsampling(feats.unsqueeze(dim=1)).squeeze(dim=1)

    def _pre_transformer_forward(
        self,
        z: torch.Tensor,
        neuro_device_type: str | None = None,
    ) -> torch.Tensor:
        """Add positional/device embeddings, normalise, and prepend the [CLS] token."""
        if self.rel_pos_emb is not None:
            # The conv operates on (B, F, T); move back to (B, T, F) afterwards.
            z = z + self.rel_pos_emb(z.transpose(1, 2)).transpose(1, 2)
        if neuro_device_type is not None and self.neuro_device_types is not None:
            idx = torch.tensor(
                self.neuro_device_types.index(neuro_device_type), device=z.device
            )
            # (F,) -> (1, 1, F) so the embedding broadcasts over batch and time.
            z = z + self.neuro_device_emb(idx).unsqueeze(0).unsqueeze(0)
        if self.pre_transformer_layer_norm is not None:
            z = self.pre_transformer_layer_norm(z)
        if self.cls_token is None:
            return z
        batch_cls = self.cls_token.repeat(z.shape[0], 1, 1)
        return torch.cat([batch_cls, z], dim=1)

    def forward(  # type: ignore
        self,
        x: torch.Tensor,
        subject_ids: torch.Tensor | None = None,
        channel_positions: torch.Tensor | None = None,
        neuro_device_type: str | None = None,
    ) -> dict[str, torch.Tensor]:
        """Forward pass through encoder, optional transformer, and output layer.

        Parameters
        ----------
        x : Tensor
            Input of shape ``(B, C, T)``.
        subject_ids : Tensor or None
            Per-example subject indices, shape ``(B,)``.
        channel_positions : Tensor or None
            Normalised electrode coordinates, shape ``(B, C, D)``.
        neuro_device_type : str or None
            Name of device represented in the batch (e.g. ``"Eeg"``,
            ``"Meg"``). If ``None``, the device embedding is not applied.
        """
        z = self._encoder_and_downsampling_forward(
            x, subject_ids=subject_ids, channel_positions=channel_positions
        )
        if self.transformer is not None:
            c_out = self.transformer(
                self._pre_transformer_forward(z, neuro_device_type=neuro_device_type)
            )
        else:
            c_out = z
        if self.output_avg_pool:
            if c_out.ndim != 3:
                raise ValueError(f"Expected 3D tensor for avg pool, got {c_out.ndim}D")
            c_out = c_out.mean(dim=1)
        if self.output_layer is not None:
            c_out = self.output_layer(c_out)
        return {
            "z": z,
            "c_out": c_out,
        }