neuraltrain.models.green.Green

pydantic model neuraltrain.models.green.Green

Reference: GREEN: A lightweight architecture using learnable wavelets and Riemannian geometry for biomarker exploration with EEG signals. See https://arxiv.org/pdf/1808.00158.

Parameters:
  • sfreq (int | None, default=None) – Sampling frequency of the input signal, in Hz.

  • n_freqs (int, default=15) – Number of main frequencies in the wavelet family.

  • kernel_width_s (int, default=5) – Width of the wavelet kernels, in seconds.

  • conv_stride (int, default=5) – Stride of the wavelet convolution.

  • oct_min (float, default=0) – Minimum frequency of interest, in octaves.

  • oct_max (float, default=5.5) – Maximum frequency of interest, in octaves.

  • random_f_init (bool, default=False) – Whether to initialize the frequencies of interest randomly.

  • shrinkage_init (float, default=-3.0) – Initial shrinkage value, before the sigmoid function is applied.

  • logref (str, default='logeuclid') – Reference matrix used by the LogEig layer.

  • dropout (float, default=0.333) – Dropout rate for the fully connected (FC) layers.

  • hidden_dim (list[int], default=[32]) – Dimensions of the hidden layers. If None, no hidden layer is used.

  • pool_layer (str, default='RealCovariance') – Pooling layer type. Options: ‘RealCovariance’, ‘PW_PLV’, ‘CombinedPooling’, ‘CrossCovariance’, ‘CrossPW_PLV’, ‘WaveletConv’.

  • bi_out (list[int] | None, default=None) – Output dimensions of the BiMap layer(s).

  • use_age (bool, default=False) – Whether to include age as an additional input to the model.

  • orth_weights (bool, default=True) – Whether to use orthogonal weight initialization.
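To make the wavelet parameters more concrete, the sketch below shows how oct_min, oct_max, n_freqs, kernel_width_s and conv_stride might translate into center frequencies and sizes. It is a minimal illustration under stated assumptions, not neuraltrain's code: it assumes center frequencies are spaced evenly in octaves (f = 2 ** octave Hz), that kernel_width_s is converted to samples by multiplying by sfreq, and that the convolution uses no padding. The helpers wavelet_grid and conv_output_len are hypothetical.

```python
import numpy as np


def wavelet_grid(sfreq: float, n_freqs: int = 15, oct_min: float = 0.0,
                 oct_max: float = 5.5, kernel_width_s: float = 5):
    """Hypothetical helper (not part of neuraltrain).

    Assumes center frequencies are evenly spaced on a log2 (octave) scale,
    i.e. f = 2 ** octave Hz, and that the kernel width in seconds is
    converted to samples by multiplying by sfreq.
    """
    octaves = np.linspace(oct_min, oct_max, n_freqs)
    foi = 2.0 ** octaves                              # center frequencies in Hz
    kernel_len = int(round(kernel_width_s * sfreq))   # kernel length in samples
    return foi, kernel_len


def conv_output_len(n_times: int, kernel_len: int, conv_stride: int = 5) -> int:
    """Output length of an unpadded ('valid') strided 1-D convolution (assumption)."""
    return (n_times - kernel_len) // conv_stride + 1


foi, kernel_len = wavelet_grid(sfreq=100)
print(foi.round(2))   # 15 frequencies, from 1.0 Hz up to about 45.25 Hz
print(kernel_len)     # 500
print(conv_output_len(n_times=3000, kernel_len=kernel_len))  # 501
```

Under these assumptions, the defaults oct_min=0 and oct_max=5.5 span roughly 1 Hz to 45 Hz, and a larger conv_stride shortens the wavelet output proportionally.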

Fields:
field sfreq: int | None = None
field n_freqs: int = 15
field kernel_width_s: int = 5
field conv_stride: int = 5
field oct_min: float = 0
field oct_max: float = 5.5
field random_f_init: bool = False
field shrinkage_init: float = -3.0
field logref: str = 'logeuclid'
field dropout: float = 0.333
field hidden_dim: list[int] = [32]
field pool_layer: Literal['RealCovariance', 'PW_PLV', 'CombinedPooling', 'CrossCovariance', 'CrossPW_PLV', 'WaveletConv'] = 'RealCovariance'
field bi_out: list[int] | None = None
field use_age: bool = False
field orth_weights: bool = True
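The shrinkage_init and logref fields govern how covariance matrices are turned into features. The NumPy sketch below illustrates one plausible reading, under explicit assumptions: shrinkage_init is treated as a logit whose sigmoid gives the weight pulling the sample covariance toward a scaled identity, and logref='logeuclid' is taken to mean re-centering the matrix logarithms on their log-Euclidean mean. The function names are hypothetical and this is not neuraltrain's implementation.

```python
import numpy as np


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + np.exp(-x))


def shrunk_covariance(x: np.ndarray, shrinkage_logit: float = -3.0) -> np.ndarray:
    """x: (n_channels, n_times). Assumed: shrink toward trace-scaled identity."""
    x = x - x.mean(axis=1, keepdims=True)
    cov = x @ x.T / x.shape[1]
    n_ch = cov.shape[0]
    alpha = sigmoid(shrinkage_logit)                 # sigmoid(-3.0) is about 0.047
    target = (np.trace(cov) / n_ch) * np.eye(n_ch)
    return (1.0 - alpha) * cov + alpha * target


def logm_spd(cov: np.ndarray) -> np.ndarray:
    """Matrix logarithm of an SPD matrix via eigendecomposition (LogEig-style)."""
    eigvals, eigvecs = np.linalg.eigh(cov)
    return (eigvecs * np.log(eigvals)) @ eigvecs.T   # V diag(log w) V^T


def log_euclid_tangent(covs: np.ndarray) -> np.ndarray:
    """covs: (n_matrices, n_ch, n_ch) -> (n_matrices, n_ch * (n_ch + 1) // 2).

    Assumed meaning of logref='logeuclid': subtract the mean of the matrix
    logarithms (the log of the log-Euclidean mean), then keep the upper triangle.
    """
    logs = np.stack([logm_spd(c) for c in covs])
    centered = logs - logs.mean(axis=0)
    iu = np.triu_indices(covs.shape[-1])
    return centered[:, iu[0], iu[1]]


rng = np.random.default_rng(0)
epochs = rng.standard_normal((4, 8, 1000))   # 4 epochs, 8 channels, 1000 samples
covs = np.stack([shrunk_covariance(ep) for ep in epochs])
feats = log_euclid_tangent(covs)
print(feats.shape)  # (4, 36)
```

Many tangent-space implementations also scale off-diagonal entries by sqrt(2) so that Euclidean norms match the matrix norm; the sketch omits that for brevity.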
build(n_in_channels: int, n_outputs: int, sfreq: int | None = None) → Module
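As a rough guide to how n_in_channels and n_outputs passed to build might shape the network head, the sketch below estimates the widths of the fully connected layers. It rests on assumptions that the source does not confirm: pool_layer='RealCovariance' yields one covariance per wavelet frequency, each vectorized as its upper triangle; bi_out, if given, sets the final channel count through its last entry; and use_age appends a single scalar feature. head_layer_sizes is a hypothetical helper, the sfreq argument is not modelled, and other pooling layers would change the feature size.

```python
from typing import Optional, Sequence


def head_layer_sizes(n_in_channels: int, n_outputs: int, n_freqs: int = 15,
                     hidden_dim: Optional[Sequence[int]] = (32,),
                     bi_out: Optional[Sequence[int]] = None,
                     use_age: bool = False) -> list:
    """Hypothetical helper: estimated widths of the fully connected head."""
    # Channel count after the optional BiMap stack (assumed: its last output size).
    n_ch = bi_out[-1] if bi_out else n_in_channels
    # One upper triangle of n_ch * (n_ch + 1) / 2 values per wavelet frequency.
    feat_dim = n_freqs * n_ch * (n_ch + 1) // 2
    if use_age:
        feat_dim += 1  # assumed: age appended as one extra feature
    hidden = list(hidden_dim) if hidden_dim else []  # None or empty: no hidden layer
    return [feat_dim, *hidden, n_outputs]


print(head_layer_sizes(n_in_channels=19, n_outputs=2))              # [2850, 32, 2]
print(head_layer_sizes(n_in_channels=19, n_outputs=2, bi_out=[8]))  # [540, 32, 2]
```

Under these assumptions, the feature size grows quadratically with the channel count, which is one reason a BiMap reduction (bi_out) can shrink the head substantially.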