# Source code for neuraltrain.models.green

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

from torch import nn

from .base import BaseModelConfig


class Green(BaseModelConfig):
    """Configuration for the GREEN model (learnable wavelets + Riemannian geometry).

    Reference: GREEN: A lightweight architecture using learnable wavelets and
    Riemannian geometry for biomarker exploration with EEG signals.

    NOTE(review): the link previously given for this reference,
    https://arxiv.org/pdf/1808.00158, appears to point to the SincNet paper
    (Ravanelli & Bengio, 2018) rather than GREEN -- confirm and replace it
    with the GREEN publication.

    The network itself is created by ``get_green`` from the third-party
    ``green`` package (``green.research_code.pl_utils``), which is only
    imported when :meth:`build` is called.

    Parameters
    ----------
    sfreq : int | None, default=None
        Sampling frequency of the signal. A truthy ``sfreq`` passed to
        :meth:`build` takes precedence over this value.
    n_freqs : int, default=15
        Number of main frequencies in the wavelet family.
    kernel_width_s : int, default=5
        Width of the kernel in seconds for the wavelets.
    conv_stride : int, default=5
        Stride of the convolution operation for the wavelets.
    oct_min : float, default=0
        Minimum frequency of interest in octave.
    oct_max : float, default=5.5
        Maximum frequency of interest in octave.
    random_f_init : bool, default=False
        Whether to randomly initialize the frequency of interest.
    shrinkage_init : float, default=-3.0
        Initial shrinkage value before applying sigmoid function.
    logref : str, default='logeuclid'
        Reference matrix used for LogEig layer.
    dropout : float, default=0.333
        Dropout rate for FC layers.
    hidden_dim : list[int] | None, default=[32]
        Dimension of the hidden layer. If None, no hidden layer.
    pool_layer : str, default='RealCovariance'
        Pooling layer type, given as the name of a class in
        ``green.wavelet_layers``. Options: 'RealCovariance', 'PW_PLV',
        'CombinedPooling', 'CrossCovariance', 'CrossPW_PLV', 'WaveletConv'.
    bi_out : list[int] | None, default=None
        Dimension of the output layer after BiMap.
    use_age : bool, default=False
        Whether to include age in the model.
    orth_weights : bool, default=True
        Whether to use orthogonal weight initialization.
    """

    sfreq: int | None = None
    n_freqs: int = 15
    kernel_width_s: int = 5
    conv_stride: int = 5
    oct_min: float = 0
    oct_max: float = 5.5
    random_f_init: bool = False
    shrinkage_init: float = -3.0
    logref: str = "logeuclid"
    dropout: float = 0.333
    # None is accepted to match the documented "no hidden layer" contract.
    hidden_dim: list[int] | None = [32]
    pool_layer: tp.Literal[
        "RealCovariance",
        "PW_PLV",
        "CombinedPooling",
        "CrossCovariance",
        "CrossPW_PLV",
        "WaveletConv",
    ] = "RealCovariance"
    bi_out: list[int] | None = None
    use_age: bool = False
    orth_weights: bool = True

    def build(
        self, n_in_channels: int, n_outputs: int, sfreq: int | None = None
    ) -> nn.Module:
        """Instantiate the GREEN network from this configuration.

        Parameters
        ----------
        n_in_channels : int
            Number of input channels; forwarded to ``get_green`` as ``n_ch``.
        n_outputs : int
            Number of outputs; forwarded to ``get_green`` as ``out_dim``.
        sfreq : int | None, default=None
            Sampling frequency. When truthy it overrides the configured
            ``sfreq``; otherwise the configured value is used.

        Returns
        -------
        nn.Module
            The module returned by ``get_green``.

        Raises
        ------
        ValueError
            If no usable (non-zero, non-None) sampling frequency is available
            from either the argument or the configuration.
        """
        # Deferred imports: the config class can be defined and validated
        # without the optional ``green`` package being installed.
        import green.wavelet_layers as wl  # type: ignore
        from green.research_code.pl_utils import get_green  # type: ignore

        kwargs = self.model_dump()
        # ``name`` is a config field (presumably the BaseModelConfig
        # discriminator), not a get_green argument -- TODO confirm.
        del kwargs["name"]

        sfreq = sfreq or self.sfreq
        # A 0 Hz sampling frequency is as unusable as a missing one.
        if not sfreq:
            raise ValueError("sfreq must be provided to build the model.")
        kwargs["sfreq"] = sfreq

        # The config stores the pooling layer by class name; get_green is
        # given a fresh instance of that class from green.wavelet_layers.
        kwargs["pool_layer"] = getattr(wl, self.pool_layer)()
        kwargs["out_dim"] = n_outputs
        kwargs["n_ch"] = n_in_channels
        return get_green(**kwargs)