neuraltrain

neuraltrain turns neuralset datasets into trained PyTorch models. Define models, losses, metrics, and optimizers as typed pydantic configs, wrap them in an experiment, and launch local runs or Slurm sweeps, all serializable, validated, and reproducible.

pip install neuraltrain

For the full model zoo, Lightning support, and dev tools:

pip install 'neuraltrain[dev,lightning,models]'

Quickstart

Define training pieces as config objects, then build concrete PyTorch modules at runtime.

import torch

from neuraltrain.losses import base as losses
from neuraltrain.metrics import base as metrics
from neuraltrain import models
from neuraltrain.optimizers import base as optimizers

# Model
model_cfg = models.SimpleConvTimeAgg(hidden=32, depth=4, merger_config=None)
model = model_cfg.build(n_in_channels=208, n_outputs=4)

# Loss & metrics
loss_cfg = losses.CrossEntropyLoss()
metric_cfg = metrics.Accuracy(
    log_name="acc",
    kwargs={"task": "multiclass", "num_classes": 4},
)

# Optimizer
optim_cfg = optimizers.LightningOptimizer(
    optimizer=optimizers.Adam(lr=1e-4),
    scheduler=optimizers.OneCycleLR(
        kwargs={"max_lr": 3e-3, "pct_start": 0.2},
    ),
)

# Dummy batch: 8 samples, 208 channels, 120 time steps
x = torch.randn(8, 208, 120)
y = torch.randint(0, 4, (8,))

logits = model(x)
loss = loss_cfg.build()(logits, y)

metric = metric_cfg.build()
metric.update(logits, y)

# Build the optimizer/scheduler pair; OneCycleLR needs the total step count.
optimizer_bundle = optim_cfg.build(model.parameters(), total_steps=100)

print(logits.shape)
print(float(loss))
print(metric.compute())
print(sorted(optimizer_bundle))
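The intro describes these configs as serializable and validated. The shape of that pattern, stripped to stdlib dataclasses, is sketched below; `ConvConfig` and its fields are hypothetical stand-ins, not neuraltrain's actual API, which uses pydantic models.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class ConvConfig:
    """Hypothetical stand-in for a typed model config."""
    hidden: int = 32
    depth: int = 4

    def build(self, n_in_channels: int, n_outputs: int):
        # In neuraltrain the equivalent call returns a torch.nn.Module;
        # here we return the resolved settings just to show the pattern.
        return {"hidden": self.hidden, "depth": self.depth,
                "n_in_channels": n_in_channels, "n_outputs": n_outputs}


cfg = ConvConfig(hidden=64)
payload = json.dumps(asdict(cfg))             # serialize for a sweep / Slurm job
restored = ConvConfig(**json.loads(payload))  # rebuild on the worker
assert restored == cfg
```

Because the config round-trips through plain JSON, the same object that drove a local run can be shipped to a Slurm job unchanged.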

Tutorials

Each tutorial covers one stage of the training pipeline.

Data
Use neuralset studies and a Segmenter to build train/val/test loaders.
events = self.study.run()
dataset = self.segmenter.apply(events)
dataset.prepare()
loaders = {split: DataLoader(ds, ...) for split, ds in ...}
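The split-to-loader mapping above can be sketched with a stdlib batching helper; `batched`, the split names, and the batch size are illustrative assumptions, standing in for torch's `DataLoader`.

```python
def batched(items, batch_size):
    """Minimal stand-in for a DataLoader: yield fixed-size batches."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]


# Toy per-split datasets standing in for the Segmenter's output.
splits = {"train": list(range(10)), "val": list(range(4)), "test": list(range(4))}
loaders = {split: list(batched(ds, batch_size=4)) for split, ds in splits.items()}
print(len(loaders["train"]))  # 3 batches: two of 4 items, one of 2
```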
Model Config
Define serializable model configs and build the PyTorch module when shapes are known.
from neuraltrain import models

model_cfg = models.SimpleConvTimeAgg(
    hidden=32, depth=4, merger_config=None
)
model = model_cfg.build(n_in_channels=208, n_outputs=4)
Objective
Compose losses, metrics, optimizers, and schedulers as typed config objects.
loss = CrossEntropyLoss()
metric = Accuracy(
    log_name="acc",
    kwargs={"task": "multiclass", "num_classes": 4},
)
optim = LightningOptimizer(...)
Trainer
Wrap the model in a Lightning module with train, validation, and test loops.
brain_module = BrainModule(
    model=brain_model,
    loss=self.loss.build(),
    optim_config=self.optim,
    metrics={m.log_name: m.build() for m in self.metrics},
)
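The BrainModule above bundles model, loss, and metrics into Lightning's train/validation hooks. The division of labour can be sketched framework-free; `MiniModule` and the toy callables are hypothetical, not the library's classes.

```python
class MiniModule:
    """Framework-free sketch of what a Lightning-style module wires together."""

    def __init__(self, model, loss, metrics):
        self.model = model      # callable: inputs -> predictions
        self.loss = loss        # callable: (preds, target) -> scalar
        self.metrics = metrics  # dict of name -> callable(preds, target) -> scalar

    def training_step(self, batch):
        x, y = batch
        return self.loss(self.model(x), y)

    def validation_step(self, batch):
        x, y = batch
        preds = self.model(x)
        return {name: fn(preds, y) for name, fn in self.metrics.items()}


# Toy usage with plain-Python stand-ins for model, loss, and metric.
module = MiniModule(
    model=lambda x: [v * 2 for v in x],
    loss=lambda p, y: sum(abs(a - b) for a, b in zip(p, y)),
    metrics={"mae": lambda p, y: sum(abs(a - b) for a, b in zip(p, y)) / len(y)},
)
batch = ([1, 2, 3], [2, 4, 6])
print(module.training_step(batch))    # loss is 0: predictions match targets
print(module.validation_step(batch))  # per-metric dict, keyed like log_name
```

The point of the split is that the trainer only ever calls the hooks, so swapping the model, loss, or metric configs never touches the loop itself.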