API Reference

Some parts of neuraltrain depend on optional packages such as braindecode, x_transformers, green, or dalle2_pytorch. Install with the models extra (pip install "neuraltrain-repo/.[models]") when you want the full set of model and augmentation modules available.

Core Interfaces

BaseModelConfig

Base class for model configurations.

BaseLoss

Base class for loss configurations.

BaseTorchLoss

Base class for torch loss configurations.

BaseMetric

Base class for metric configurations.

BaseTorchMetric

Base class for torchmetrics configurations.

BaseOptimizer

Base class for optimizer configurations.

LightningOptimizer

Pydantic configuration for Lightning optimizer.

Optimizers & Schedulers

BaseTorchOptimizer

Base class for torch optimizer configurations.

BaseLRScheduler

Base class for learning rate scheduler configurations.

BaseTorchLRScheduler

Base class for torch LR scheduler configurations.

Models

When braindecode is installed, every model it exports (EEGNet, ShallowFBCSPNet, Deep4Net, etc.) is auto-registered as a BaseBrainDecodeModel subclass in neuraltrain.models.base. Use models.EEGNet(kwargs={...}); see Installation for an example.
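
This auto-registration relies on subclass discovery. A toy illustration of the general pattern, not neuraltrain's actual code; every name below is made up for the sketch:

```python
class BaseModel:
    """Base class holding a registry of all its subclasses."""

    registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every subclass registers itself under its class name, so a
        # models namespace can expose it without listing each one.
        BaseModel.registry[cls.__name__] = cls


class EEGNetConfig(BaseModel):
    """Hypothetical auto-generated config wrapping a braindecode model."""

    def __init__(self, kwargs=None):
        # Extra constructor arguments are passed through `kwargs`.
        self.kwargs = kwargs or {}


# The subclass is now discoverable by name.
cfg = BaseModel.registry["EEGNetConfig"](kwargs={"n_chans": 64})
```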

BaseBrainDecodeModel

Base class for braindecode model configurations.

SimpleConv

1-D convolutional encoder, adapted from brainmagick.

SimpleConvModel

nn.Module implementation of SimpleConv.

SimpleConvTimeAgg

SimpleConv with temporal aggregation layer and optional output heads.

SimpleConvTimeAggModel

nn.Module implementation of SimpleConvTimeAgg.

SimplerConv

Convolutional encoder inspired by BENDR / wav2vec 2.0.

SimplerConvModel

Convolutional encoder inspired by BENDR / wav2vec 2.0, with channel attention.

TransformerEncoder

Transformer encoder/decoder built on top of x_transformers.

ConvTransformer

Convolutional encoder followed by optional temporal aggregation and a transformer.

ConvTransformerModel

nn.Module implementation of ConvTransformer.

Conformer

Reference: Gulati et al., "Conformer: Convolution-augmented Transformer for Speech Recognition", Interspeech 2020.

Linear

Simple linear projection, with optional per-subject weights.

LinearModel

nn.Module implementation of Linear.

ConstantPredictor

Constant predictor that predicts the most frequent class or the mean of the targets.
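
A rough stand-alone sketch of the behavior described above (class and method names here are hypothetical, not neuraltrain's API):

```python
from collections import Counter
from statistics import mean


class ConstantBaseline:
    """Predict a single constant: the most frequent class for
    classification, or the target mean for regression."""

    def fit(self, targets, task="classification"):
        if task == "classification":
            # most_common(1) returns [(value, count)] for the mode.
            self.constant = Counter(targets).most_common(1)[0][0]
        else:
            self.constant = mean(targets)
        return self

    def predict(self, n):
        return [self.constant] * n
```

Such a baseline is useful as a sanity check: a trained model should at least beat it.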

ConstantPredictorModel

nn.Module implementation of ConstantPredictor.

FmriMlp

Residual MLP for fMRI decoding, adapted from MindEye.

FmriMlpModel

Residual MLP adapted from MindEye.

FmriLinear

Single linear layer for fMRI decoding with temporal aggregation.

FmriLinearModel

nn.Module implementation of FmriLinear.

NtReve

Config for the braindecode REVE model with channel-mapping support.

NtLuna

Config for the braindecode LUNA model.

NtLabram

Config for the braindecode LaBraM model with pretrained-model support.

FreqBandNet

Parametrized filterbank feature extractor (sinc filters + power + log).

FreqBandNetModel

Simple parametrized filterbank feature extractor (bandpass filters + power extraction + log nonlinearity + optional MLP output head).

Green

Reference: GREEN: A lightweight architecture using learnable wavelets and Riemannian geometry for biomarker exploration with EEG signals.

DiffusionPrior

Diffusion prior module adapted from MindEye.

OnTheFlyPreprocessor

Module to apply common preprocessing steps on-the-fly, inside an nn.Module.

OnTheFlyPreprocessorModel

nn.Module implementation of OnTheFlyPreprocessor.

Shared Building Blocks

SubjectLayers

Configuration for per-subject linear projections.

FourierEmb

Configuration for Fourier positional embedding.

ChannelMerger

Configuration for the ChannelMerger module.

Mlp

Multilayer perceptron, e.g. for use as projection head.

TemporalDownsampling

Temporal downsampling via a 2-D convolution over the time axis.

LayerScale

Layer scale from Touvron et al., 2021 (https://arxiv.org/pdf/2103.17239.pdf).

NormDenormScaler

Norm-denorm scaler inspired by [1].

ChannelDropout

Dropout applied to entire input channels.

BahdanauAttention

Bahdanau attention (Bahdanau et al., 2015).

Losses

MultiLoss

Weighted combination of multiple loss terms.
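
A weighted combination of loss terms reduces to a weighted sum. A minimal stand-alone sketch with hypothetical names (the real class likely operates on tensors):

```python
def multi_loss(losses, weights, pred, target):
    """Return sum of weight_i * loss_i(pred, target)."""
    return sum(w * fn(pred, target) for fn, w in zip(losses, weights))


def mse(p, t):
    return sum((pi - ti) ** 2 for pi, ti in zip(p, t)) / len(p)


def mae(p, t):
    return sum(abs(pi - ti) for pi, ti in zip(p, t)) / len(p)


total = multi_loss([mse, mae], [1.0, 0.5], [1.0, 2.0], [0.0, 2.0])
# mse = 0.5, mae = 0.5, so total = 1.0 * 0.5 + 0.5 * 0.5 = 0.75
```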

ClipLoss

CLIP contrastive loss.

SigLipLoss

SigLIP contrastive loss.

Metrics

OnlinePearsonCorr

Online Pearson correlation coefficient.
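
An online Pearson correlation can be maintained from running sums alone, so past batches never need to be stored. An illustrative stand-alone sketch (not the neuraltrain implementation, which presumably works on tensors):

```python
import math


class OnlinePearson:
    """Streaming Pearson correlation via running co-moments."""

    def __init__(self):
        self.n = 0
        self.sx = self.sy = self.sxx = self.syy = self.sxy = 0.0

    def update(self, xs, ys):
        # Accumulate sums; no raw data is retained.
        for x, y in zip(xs, ys):
            self.n += 1
            self.sx += x
            self.sy += y
            self.sxx += x * x
            self.syy += y * y
            self.sxy += x * y

    def compute(self):
        cov = self.sxy - self.sx * self.sy / self.n
        vx = self.sxx - self.sx ** 2 / self.n
        vy = self.syy - self.sy ** 2 / self.n
        return cov / math.sqrt(vx * vy)
```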

Rank

Rank of predictions based on a retrieval set, using cosine similarity.
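
The idea behind this metric can be sketched in plain Python (function names hypothetical; the actual metric presumably batches this over tensors):

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


def retrieval_rank(pred, retrieval_set, true_index):
    """Rank (0 = best) of the true item when the retrieval set is
    sorted by cosine similarity to the prediction."""
    sims = [cosine(pred, item) for item in retrieval_set]
    order = sorted(range(len(sims)), key=lambda i: -sims[i])
    return order.index(true_index)
```

A rank of 0 means the true target was retrieved first; averaging ranks over a dataset gives a retrieval-quality score.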

TopkAcc

Top-k accuracy.
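
Top-k accuracy counts a sample as correct when its true class appears among the k highest-scoring classes. A minimal sketch (the real metric presumably follows torchmetrics conventions):

```python
def topk_accuracy(scores, targets, k=1):
    """Fraction of samples whose true class is among the k
    highest-scoring classes."""
    hits = 0
    for row, t in zip(scores, targets):
        # Indices of the k largest scores in this row.
        topk = sorted(range(len(row)), key=lambda i: -row[i])[:k]
        hits += t in topk
    return hits / len(targets)
```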

TopkAccFromScores

Top-k accuracy computed from already available similarity scores.

ImageSimilarity

Image similarity metric based on feature extraction from a pretrained network.

GroupedMetric

A wrapper around a torchmetrics.Metric that allows for computing metrics per group.

Augmentations

TrivialBrainAugment

Inspired by TrivialAugment (Müller & Hutter, 2021), samples augmentations and their strengths randomly on each minibatch/forward pass.

TrivialBrainAugmentConfig

Configuration for TrivialBrainAugment.

BandstopFilterFFT

Bandstop data augmentation, applying a bandstop filter to the data using Fourier transform.

BandstopFilterFFTConfig

Configuration for BandstopFilterFFT.

ChannelsDropoutConfig

Configuration for braindecode's ChannelsDropout augmentation.

FrequencyShiftConfig

Configuration for braindecode's FrequencyShift augmentation.

GaussianNoiseConfig

Configuration for braindecode's GaussianNoise augmentation.

SmoothTimeMaskConfig

Configuration for braindecode's SmoothTimeMask augmentation.

Experiment Utilities

run_grid

Run grid over provided experiment.
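
A grid run expands a mapping of parameter lists into one configuration per combination. A stand-alone sketch of that expansion step (run_grid's actual signature may differ):

```python
from itertools import product


def expand_grid(grid):
    """Expand {param: [values]} into a list of config dicts,
    one per combination of values."""
    keys = list(grid)
    return [dict(zip(keys, combo))
            for combo in product(*(grid[k] for k in keys))]


configs = expand_grid({"lr": [1e-3, 1e-4], "batch_size": [32, 64]})
# 4 configurations, one per (lr, batch_size) pair
```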

CsvLoggerConfig

Pydantic configuration for PyTorch Lightning's CSVLogger.

WandbLoggerConfig

Pydantic configuration for PyTorch Lightning's WandbLogger.

WandbInfra

BaseExperiment

Base experiment class, which requires an infra and a run method.

StandardScaler

Standard scaler that can be fitted batch by batch and handles 2-dimensional extractors.
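
Batch-wise fitting can be done with running sums of x and x², so no batch has to be kept in memory. A one-dimensional sketch with hypothetical names (the real class handles richer shapes):

```python
class BatchStandardScaler:
    """Mean/variance fitted incrementally, one batch at a time."""

    def __init__(self):
        self.n = 0
        self.s = 0.0   # running sum of x
        self.ss = 0.0  # running sum of x^2

    def partial_fit(self, batch):
        for x in batch:
            self.n += 1
            self.s += x
            self.ss += x * x
        return self

    def transform(self, batch):
        mean = self.s / self.n
        var = self.ss / self.n - mean ** 2
        std = var ** 0.5
        return [(x - mean) / std for x in batch]
```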

TimedIterator

Iterator wrapper that keeps track of recent fetch durations, as well as call-to-call durations.
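
A minimal sketch of such a timing wrapper (class and attribute names are illustrative, not TimedIterator's actual API):

```python
import time


class TimingIterator:
    """Wrap an iterator, recording how long each fetch took and the
    gap between successive fetches."""

    def __init__(self, iterable):
        self.it = iter(iterable)
        self.fetch_durations = []
        self.call_gaps = []
        self._last_end = None

    def __iter__(self):
        return self

    def __next__(self):
        start = time.perf_counter()
        item = next(self.it)  # StopIteration propagates unchanged
        end = time.perf_counter()
        self.fetch_durations.append(end - start)
        if self._last_end is not None:
            self.call_gaps.append(end - self._last_end)
        self._last_end = end
        return item
```

Comparing fetch durations with call-to-call gaps shows whether a training loop is bottlenecked by data loading or by compute.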

convert_to_pydantic

Converts any class into a pydantic BaseModel.