neuraltrain.models.green.Green

class neuraltrain.models.green.Green(*, sfreq: int | None = None, n_freqs: int = 15, kernel_width_s: int = 5, conv_stride: int = 5, oct_min: float = 0, oct_max: float = 5.5, random_f_init: bool = False, shrinkage_init: float = -3.0, logref: str = 'logeuclid', dropout: float = 0.333, hidden_dim: list[int] = [32], pool_layer: Literal['RealCovariance', 'PW_PLV', 'CombinedPooling', 'CrossCovariance', 'CrossPW_PLV', 'WaveletConv'] = 'RealCovariance', bi_out: list[int] | None = None, use_age: bool = False, orth_weights: bool = True)

Reference: GREEN: A lightweight architecture using learnable wavelets and Riemannian geometry for biomarker exploration with EEG signals. See https://arxiv.org/pdf/1808.00158.

Parameters:
  • sfreq (int | None, default=None) – Sampling frequency of the input signal, in Hz.

  • n_freqs (int, default=15) – Number of center frequencies (i.e., wavelets) in the wavelet family.

  • kernel_width_s (int, default=5) – Width of the wavelet kernels, in seconds.

  • conv_stride (int, default=5) – Stride of the wavelet convolution, in samples.

  • oct_min (float, default=0) – Minimum frequency of interest, in octaves.

  • oct_max (float, default=5.5) – Maximum frequency of interest, in octaves.

  • random_f_init (bool, default=False) – Whether to initialize the frequencies of interest randomly.

  • shrinkage_init (float, default=-3.0) – Initial shrinkage value, before the sigmoid function is applied; the default corresponds to an initial shrinkage of sigmoid(-3.0) ≈ 0.047.

  • logref (str, default='logeuclid') – Reference matrix used by the LogEig layer.

  • dropout (float, default=0.333) – Dropout rate applied in the fully connected (FC) layers.

  • hidden_dim (list[int], default=[32]) – Dimensions of the hidden FC layer(s). If None, no hidden layer is used.

  • pool_layer (str, default='RealCovariance') – Pooling layer type. Options: ‘RealCovariance’, ‘PW_PLV’, ‘CombinedPooling’, ‘CrossCovariance’, ‘CrossPW_PLV’, ‘WaveletConv’.

  • bi_out (list[int] | None, default=None) – Output dimension(s) of the BiMap layer(s).

  • use_age (bool, default=False) – Whether to include the subject's age as an additional input to the model.

  • orth_weights (bool, default=True) – Whether to use orthogonal weight initialization.
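The wavelet parameters (sfreq, n_freqs, kernel_width_s, conv_stride, oct_min, oct_max) can be illustrated with a small NumPy sketch. It assumes that an octave value o corresponds to 2**o Hz, that center frequencies are evenly spaced in octaves, and that the wavelets are complex Morlet (Gabor) kernels; the fixed Gaussian width fwhm_s and all function names are hypothetical and not part of neuraltrain. With the defaults, oct_min=0 and oct_max=5.5 would then span roughly 1 Hz to 45 Hz.

```python
import numpy as np


def center_frequencies(oct_min: float = 0.0, oct_max: float = 5.5, n_freqs: int = 15) -> np.ndarray:
    """Return n_freqs center frequencies (Hz), evenly spaced in octaves.

    Assumption: octave value o maps to 2**o Hz.
    """
    return 2.0 ** np.linspace(oct_min, oct_max, n_freqs)


def morlet_kernels(freqs: np.ndarray, sfreq: int, kernel_width_s: float, fwhm_s: float = 0.5) -> np.ndarray:
    """Build one complex Morlet (Gabor) kernel per center frequency.

    fwhm_s is a hypothetical fixed Gaussian width, used only for illustration.
    """
    n = int(kernel_width_s * sfreq)                      # kernel length in samples
    t = (np.arange(n) - n // 2) / sfreq                  # time axis centered on 0, in seconds
    envelope = np.exp(-4.0 * np.log(2.0) * t**2 / fwhm_s**2)
    carriers = np.exp(2j * np.pi * np.outer(freqs, t))   # shape (n_freqs, n)
    # Normalized so a complex exponential at the center frequency has unit gain
    return envelope * carriers / envelope.sum()


def wavelet_transform(x: np.ndarray, kernels: np.ndarray, conv_stride: int) -> np.ndarray:
    """Convolve a 1-D signal with every kernel and keep every conv_stride-th output sample."""
    out = np.stack([np.convolve(x, k, mode="valid") for k in kernels])
    return out[:, ::conv_stride]


# Usage: a 30 s, 8 Hz sine sampled at 100 Hz.
sfreq = 100
t = np.arange(30 * sfreq) / sfreq
x = np.sin(2 * np.pi * 8.0 * t)
freqs = center_frequencies(0.0, 5.5, 15)
coefs = wavelet_transform(x, morlet_kernels(freqs, sfreq, kernel_width_s=5), conv_stride=5)
power = np.abs(coefs).mean(axis=1)   # average magnitude per wavelet
print(coefs.shape)                   # (15, 501)
print(freqs[np.argmax(power)])       # the center frequency closest to 8 Hz
```

A 5 s kernel at 100 Hz is 500 samples long, so a "valid" convolution of 3000 samples yields 2501 outputs, and a stride of 5 keeps 501 of them.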
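The shrinkage_init parameter describes an unconstrained value passed through a sigmoid. A minimal sketch of how such a value might regularize a channel covariance follows; it assumes a shrinkage target of (trace(C) / n_channels) * I and a convex mix (1 - alpha) * C + alpha * target, which is a common choice but not confirmed as GREEN's exact estimator. The function names are hypothetical.

```python
import numpy as np


def sigmoid(z: float) -> float:
    """Map an unconstrained real value into the open interval (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def shrunk_covariance(x: np.ndarray, shrinkage_logit: float = -3.0) -> np.ndarray:
    """Covariance of x (channels x times) shrunk toward a scaled identity.

    Assumption: the target is (trace(C) / n_channels) * I and the mix is
    (1 - alpha) * C + alpha * target with alpha = sigmoid(shrinkage_logit).
    """
    x = x - x.mean(axis=1, keepdims=True)             # remove per-channel mean
    n_channels, n_times = x.shape
    cov = x @ x.T / n_times                           # sample covariance
    alpha = sigmoid(shrinkage_logit)
    target = np.trace(cov) / n_channels * np.eye(n_channels)
    return (1.0 - alpha) * cov + alpha * target


# Usage: 8 channels but only 4 time samples, so the raw covariance is singular.
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))
c = shrunk_covariance(x, shrinkage_logit=-3.0)
print(round(sigmoid(-3.0), 4))                        # 0.0474
print(np.linalg.eigvalsh(c).min() > 0)                # True: shrinkage restores positive definiteness
```

Learning the logit rather than alpha itself keeps the effective shrinkage in (0, 1) without constrained optimization, and this particular target leaves the trace of the covariance unchanged.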
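The bi_out, orth_weights, and logref parameters refer to Riemannian layers operating on SPD matrices. The sketch below shows the standard forms of these operations under stated assumptions: BiMap as W^T C W with a weight matrix W whose columns are orthonormal (which is what orth_weights is assumed to enforce at initialization), and logref='logeuclid' read as taking the matrix logarithm at the identity (the log-Euclidean map). The final upper-triangle vectorization is a common way to feed the result to FC layers; none of these function names come from neuraltrain.

```python
import numpy as np


def stiefel_init(n_in: int, n_out: int, seed: int = 0) -> np.ndarray:
    """Random n_in x n_out matrix with orthonormal columns (QR of a Gaussian matrix)."""
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((n_in, n_out)))
    return q


def bimap(c: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Project an SPD matrix to a smaller SPD matrix: W^T C W."""
    return w.T @ c @ w


def logeig(c: np.ndarray) -> np.ndarray:
    """Matrix logarithm of an SPD matrix via eigendecomposition (log-Euclidean map at I)."""
    eigvals, eigvecs = np.linalg.eigh(c)
    return (eigvecs * np.log(eigvals)) @ eigvecs.T   # V diag(log λ) V^T


def upper_vectorize(s: np.ndarray) -> np.ndarray:
    """Flatten a symmetric matrix to its upper triangle, scaling off-diagonals by sqrt(2)
    so the vector's Euclidean norm equals the matrix Frobenius norm."""
    rows, cols = np.triu_indices(s.shape[0])
    weights = np.where(rows == cols, 1.0, np.sqrt(2.0))
    return s[rows, cols] * weights


# Usage: an 8 x 8 covariance reduced to 4 x 4, as a hypothetical bi_out=[4] would suggest.
rng = np.random.default_rng(1)
a = rng.standard_normal((8, 200))
c = a @ a.T / 200                    # 8 x 8 SPD covariance
w = stiefel_init(8, 4)
features = upper_vectorize(logeig(bimap(c, w)))
print(features.shape)                # (10,)
```

Because W has full column rank, W^T C W stays SPD, so its eigenvalues are positive and the logarithm is well defined.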