API Reference
Some parts of neuraltrain depend on optional packages such as
braindecode, x_transformers, green, or dalle2_pytorch.
Install the package with the models extra (for example,
pip install "neuraltrain-repo/.[models]") when you want the full set of model
and augmentation modules available.
Core Interfaces
- Base class for model configurations.
- Base class for loss configurations.
- Base class for torch loss configurations.
- Base class for metric configurations.
- Base class for torchmetrics configurations.
- Base class for optimizer configurations.
- Pydantic configuration for the Lightning optimizer.
Optimizers & Schedulers
- Base class for torch optimizer configurations.
- Base class for learning rate scheduler configurations.
- Base class for torch LR scheduler configurations.
Models
When braindecode is installed, every model it exports (EEGNet,
ShallowFBCSPNet, Deep4Net, etc.) is auto-registered as a
BaseBrainDecodeModel subclass on neuraltrain.models.base.
Use models.EEGNet(kwargs={...}); see Installation for an example.
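As a rough illustration of the auto-registration mechanism described above, a generic version of the pattern might look like the sketch below. Only the name BaseBrainDecodeModel comes from the text; the registry, the register helper, and the hard-coded model names are hypothetical stand-ins, not neuraltrain's actual source.

```python
# Illustrative sketch of auto-registering optional-dependency models as
# config subclasses. Not neuraltrain's real code: the registry and the
# `register` helper are hypothetical.

class BaseBrainDecodeModel:
    """Stand-in for the pydantic base config."""
    model_name = None

    def __init__(self, kwargs=None):
        self.kwargs = kwargs or {}

registry = {}

def register(name):
    # Dynamically create one config subclass per exported model.
    cfg = type(name, (BaseBrainDecodeModel,), {"model_name": name})
    registry[name] = cfg
    return cfg

# In the real library these names would be discovered from braindecode:
for name in ["EEGNet", "ShallowFBCSPNet", "Deep4Net"]:
    register(name)

EEGNet = registry["EEGNet"]
config = EEGNet(kwargs={"n_chans": 64, "n_outputs": 2})
```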
- Base class for braindecode model configurations.
- 1-D convolutional encoder, adapted from brainmagick.
- SimpleConv with a temporal aggregation layer and optional output heads.
- Convolutional encoder inspired by BENDR / wav2vec 2.0.
- Convolutional encoder inspired by BENDR / wav2vec 2.0, with channel attention.
- Transformer encoder/decoder built on top of ….
- Convolutional encoder followed by optional temporal aggregation and a transformer.
- Conformer-style model (reference: Gulati et al., "Conformer: Convolution-augmented Transformer for Speech Recognition", Interspeech 2020).
- Simple linear projection, with optional per-subject weights.
- Constant predictor that predicts the most frequent class or the mean of the targets.
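The constant-predictor baseline above is simple enough to sketch in plain Python. The class and method names below are illustrative only, not neuraltrain's API.

```python
from collections import Counter

# Minimal sketch of a constant-predictor baseline: for classification it
# outputs the most frequent training class; for regression, the mean of
# the training targets. Illustrative names, plain Python.

class ConstantPredictor:
    def fit(self, targets, task="classification"):
        if task == "classification":
            self.constant = Counter(targets).most_common(1)[0][0]
        else:
            self.constant = sum(targets) / len(targets)
        return self

    def predict(self, n):
        # Ignores the inputs entirely; returns the fitted constant n times.
        return [self.constant] * n
```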
- Residual MLP for fMRI decoding, adapted from MindEye [1].
- Residual MLP adapted from [1].
- Single linear layer for fMRI decoding with temporal aggregation.
- Config for the braindecode REVE model with channel-mapping support.
- Config for the braindecode LUNA model.
- Config for the braindecode LaBraM model with pretrained-model support.
- Parametrized filterbank feature extractor (sinc filters + power + log).
- Simple parametrized filterbank feature extractor (bandpass filters + power extraction + log nonlinearity + optional MLP output head). Reference: GREEN: A lightweight architecture using learnable wavelets and Riemannian geometry for biomarker exploration with EEG signals.
- Diffusion prior module adapted from MindEye [1].
- Module to apply common preprocessing steps on-the-fly, inside an nn.Module.
Losses
- Weighted combination of multiple loss terms.
- CLIP contrastive loss.
- SigLIP contrastive loss.
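To make the contrastive losses concrete, here is a minimal plain-Python sketch of a CLIP-style loss computed from a precomputed similarity matrix, where matched pairs sit on the diagonal. The library's implementation would operate on torch tensors; the function names here are illustrative.

```python
import math

# CLIP-style contrastive loss on a small similarity matrix: average of
# cross-entropy along rows (modality A -> B) and columns (B -> A), with
# the diagonal entries as the positive targets.

def softmax_xent(logits, target):
    # Numerically stable cross-entropy of a softmax over `logits`.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    return -math.log(exps[target] / sum(exps))

def clip_loss(sim):
    n = len(sim)
    rows = sum(softmax_xent(sim[i], i) for i in range(n)) / n
    cols = sum(softmax_xent([sim[j][i] for j in range(n)], i)
               for i in range(n)) / n
    return 0.5 * (rows + cols)
```

When the diagonal dominates (embeddings well aligned), the loss approaches zero; when positives score low, it grows.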
Metrics
- Online Pearson correlation coefficient.
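The idea behind an online Pearson correlation can be sketched with running sufficient statistics: accumulate sums batch by batch and only form the coefficient when computing. Plain Python below; the library's metric would be a torchmetrics-style object operating on tensors, and the names are illustrative.

```python
# Streaming Pearson correlation: update() accumulates sums, compute()
# forms r from them, so no samples need to be kept in memory.

class OnlinePearson:
    def __init__(self):
        self.n = self.sx = self.sy = 0.0
        self.sxx = self.syy = self.sxy = 0.0

    def update(self, xs, ys):
        for x, y in zip(xs, ys):
            self.n += 1
            self.sx += x
            self.sy += y
            self.sxx += x * x
            self.syy += y * y
            self.sxy += x * y

    def compute(self):
        cov = self.sxy - self.sx * self.sy / self.n
        vx = self.sxx - self.sx ** 2 / self.n
        vy = self.syy - self.sy ** 2 / self.n
        return cov / (vx * vy) ** 0.5
```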
- Rank of predictions based on a retrieval set, using cosine similarity.
- Top-k accuracy.
- Top-k accuracy computed from already available similarity scores.
- Image similarity metric based on feature extraction from a pretrained network.
- A wrapper around a torchmetrics.Metric that allows for computing metrics per group.
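The top-k accuracy metrics above share one core computation, sketched here for precomputed similarity (or logit) scores: a prediction counts as correct when the true index is among the k highest-scoring candidates. Function name and signature are illustrative.

```python
# Top-k accuracy from a score matrix: one row of candidate scores per
# sample, one true index per sample.

def topk_accuracy(scores, targets, k=1):
    hits = 0
    for row, target in zip(scores, targets):
        # Indices of the k largest scores in this row.
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += target in topk
    return hits / len(targets)
```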
Augmentations
- Inspired by TrivialAugment [1], samples augmentations and strengths randomly on each minibatch/forward pass.
- Bandstop data augmentation, applying a bandstop filter to the data using the Fourier transform.
- Configuration for ….
- Configurations for braindecode's augmentation transforms (four entries).
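The TrivialAugment-style policy listed above can be sketched as: on every minibatch, pick one augmentation uniformly at random and a random strength, then apply it. The toy augmentations and names below are stand-ins, not the library's transforms.

```python
import random

# TrivialAugment-style sampling: one random augmentation, one random
# strength, per call (i.e. per minibatch / forward pass).

def scale(batch, strength):
    return [x * (1 + strength) for x in batch]

def jitter(batch, strength):
    return [x + random.uniform(-strength, strength) for x in batch]

AUGMENTATIONS = [scale, jitter]

def trivial_augment(batch, max_strength=0.5):
    aug = random.choice(AUGMENTATIONS)
    strength = random.uniform(0.0, max_strength)
    return aug(batch, strength)
```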
Experiment Utilities
- Run a grid over the provided experiment.
- Pydantic configuration for torch-lightning's CSVLogger.
- Pydantic configuration for torch-lightning's wandb logger.
- Base experiment class, which requires an infra and a 'run' method.
- Standard scaler that can be fitted batch by batch and handles 2-dimensional extractors.
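The batch-fittable standard scaler above can be sketched with pooled running statistics: each partial fit adds to a count, a sum, and a sum of squares, and transformation standardizes with the pooled mean and standard deviation. Plain Python lists stand in for tensors; the class and method names are illustrative.

```python
# Standard scaler fitted incrementally: statistics accumulate across
# partial_fit calls, so no batch needs to be revisited.

class BatchStandardScaler:
    def __init__(self):
        self.n = 0
        self.s = 0.0
        self.ss = 0.0

    def partial_fit(self, batch):
        for x in batch:
            self.n += 1
            self.s += x
            self.ss += x * x
        return self

    def transform(self, batch):
        mean = self.s / self.n
        var = self.ss / self.n - mean ** 2
        std = var ** 0.5 or 1.0  # guard against zero variance
        return [(x - mean) / std for x in batch]
```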
- Keeps the last fetch durations of the iterator, as well as the last call-to-call durations.
- Converts any class into a pydantic BaseModel.