Note
Go to the end to download the full example code.
Trainer and Experiments¶
This tutorial shows how to wire configs into a Lightning training loop and package everything into a reproducible experiment.
Note
This tutorial requires the lightning extra:
pip install 'neuraltrain[lightning]'
See also
Project Example – full training project
Building the training pieces¶
All training pieces start as typed configs:
import torch
import torch.nn as nn
from neuraltrain import BaseLoss, BaseMetric, LightningOptimizer, models
model_config = models.SimpleConvTimeAgg(hidden=32, depth=4, merger_config=None)
loss_config = BaseLoss(name="CrossEntropyLoss")
metric_config = BaseMetric(
name="Accuracy",
log_name="acc",
kwargs={"task": "multiclass", "num_classes": 4},
)
optim_config = LightningOptimizer(
optimizer={"name": "Adam", "lr": 1e-4},
scheduler={"name": "OneCycleLR", "kwargs": {"max_lr": 3e-3, "pct_start": 0.2}},
)
brain_model = model_config.build(n_in_channels=204, n_outputs=4).cpu()
loss = loss_config.build()
metrics = {metric_config.log_name: metric_config.build()}
print("Model: ", type(brain_model).__name__)
print("Loss: ", loss)
print("Metrics: ", metrics)
print("Optimiser:", optim_config)
Model: SimpleConvTimeAggModel
Loss: CrossEntropyLoss()
Metrics: {'acc': MulticlassAccuracy()}
Optimiser: optimizer=Adam(lr=0.0001, kwargs={}) scheduler=OneCycleLR(kwargs={'max_lr': 0.003, 'pct_start': 0.2}) interval='step'
BrainModule¶
BrainModule is a LightningModule that wires the built pieces
into a training loop. It lives in project_example/pl_module.py
– you own the full training logic.
Key methods:
forward – extracts batch.data[x_name] and calls the model; _run_step – computes loss, updates metrics, and logs both; configure_optimizers – builds the optimiser from config, passing total_steps for schedulers that need it.
Here is the full implementation:
import lightning.pytorch as pl
from torchmetrics import Metric
from neuraltrain.optimizers import BaseOptimizer
class BrainModule(pl.LightningModule):
    """LightningModule wiring built pieces into a train/val/test loop.

    Parameters
    ----------
    model: built torch model, called on ``batch.data[x_name]``.
    loss: torch loss module comparing predictions to targets.
    optim_config: optimiser config whose ``build()`` returns the Lightning
        optimiser/scheduler structure.
    metrics: mapping of log-name -> torchmetrics ``Metric``.  Each metric is
        deep-copied per split so val and test keep independent state.
    x_name / y_name: keys into ``batch.data`` for inputs and targets.
    max_epochs: stored on the module for reference by callers/schedulers.
    """

    def __init__(
        self,
        model: nn.Module,
        loss: nn.Module,
        optim_config: BaseOptimizer,
        metrics: dict[str, Metric],
        x_name: str = "input",
        y_name: str = "target",
        max_epochs: int = 100,
    ) -> None:
        super().__init__()
        self.model = model
        self.x_name, self.y_name = x_name, y_name
        self.optim_config = optim_config
        self.max_epochs = max_epochs
        self.loss = loss
        # Fix: the original comprehension stored the *same* Metric instance
        # under both the "val_" and "test_" keys, so the two splits shared
        # one internal accumulator.  Deep-copy per split instead.  No keys
        # are registered for "train", so training logs only the loss.
        import copy  # stdlib only; local to keep the snippet self-contained

        self.metrics = nn.ModuleDict(
            {
                f"{split}_{k}": copy.deepcopy(v)
                for k, v in metrics.items()
                for split in ["val", "test"]
            }
        )

    def forward(self, batch):
        # Extract the named input tensor and run the model on it.
        return self.model(batch.data[self.x_name])

    def _run_step(self, batch, batch_idx, step_name):
        """Shared step: compute loss, update this split's metrics, log both."""
        y_true = batch.data[self.y_name].squeeze(-1)
        y_pred = self.forward(batch)
        loss = self.loss(y_pred, y_true)
        self.log(f"{step_name}_loss", loss, on_epoch=True, prog_bar=True)
        # Only metrics whose key starts with this split's name are updated;
        # passing the Metric object to self.log lets Lightning compute and
        # reset it at epoch end.
        for name, metric in self.metrics.items():
            if name.startswith(step_name):
                metric.update(y_pred, y_true)
                self.log(name, metric, on_epoch=True, prog_bar=True)
        return loss, y_pred, y_true

    def training_step(self, batch, batch_idx):
        loss, _, _ = self._run_step(batch, batch_idx, step_name="train")
        return loss

    def validation_step(self, batch, batch_idx):
        _, y_pred, y_true = self._run_step(batch, batch_idx, step_name="val")
        return y_pred, y_true

    def test_step(self, batch, batch_idx):
        self._run_step(batch, batch_idx, step_name="test")

    def configure_optimizers(self):
        try:
            # Schedulers such as OneCycleLR need the total step count.
            return self.optim_config.build(
                self.parameters(),
                total_steps=self.trainer.estimated_stepping_batches,
            )
        except TypeError:
            # build() does not accept total_steps for this config.
            # NOTE(review): this also masks TypeErrors raised *inside*
            # build(); consider inspecting build()'s signature instead.
            return self.optim_config.build(self.parameters())
Let’s instantiate it and run a forward pass:
# Assemble the module from the pieces built earlier.
brain_module = BrainModule(
    model=brain_model,
    loss=loss,
    optim_config=optim_config,
    metrics=metrics,
    max_epochs=20,
)

# Fake batch of shape (batch, channels, time).
# NOTE(review): 204 channels / 60 samples presumably match the dataset used
# in the project example -- confirm against the data config.
x = torch.randn(8, 204, 60)
target = torch.randint(0, 4, (8,))

brain_module.eval()  # inference mode for dropout/batch-norm layers
with torch.no_grad():  # no autograd graph needed for a smoke test
    y_pred = brain_module.model(x)
    loss_val = brain_module.loss(y_pred, target)
print(f"Predictions: {tuple(y_pred.shape)}")
print(f"Loss: {loss_val.item():.4f}")
Predictions: (8, 4)
Loss: 1.3784
The Experiment class¶
Experiment is a pydantic.BaseModel that collects all
training pieces – data, model, loss, optimiser, metrics, loggers,
and infrastructure – in one validated object:
class Experiment(pydantic.BaseModel):
    """One validated training run: data, model, loss, optimiser, metrics,
    loggers and infrastructure collected in a single pydantic model."""

    data: Data  # dataloader factory (train/val/test)
    brain_model_config: BaseModelConfig  # model architecture config
    loss: BaseLoss  # loss config, built via .build()
    optim: LightningOptimizer  # optimiser + scheduler config
    metrics: list[BaseMetric]  # metric configs, one per logged metric
    csv_config: CsvLoggerConfig | None = None  # optional CSV logger
    wandb_config: WandbLoggerConfig | None = None  # optional W&B logger
    infra: TaskInfra = TaskInfra(version="1")  # caching / Slurm execution
Its run() method orchestrates the full pipeline:
@infra.apply  # makes run() cacheable and Slurm-submittable via exca
def run(self) -> dict[str, float | None]:
    """Run the pipeline: seed, build data and module, fit, then test.

    Returns the test metrics as a name -> value mapping.
    """
    # NOTE(review): self.seed and self.test_only are not among the fields
    # shown in the class excerpt -- presumably declared elsewhere on
    # Experiment; confirm against the full project example.
    pl.seed_everything(self.seed, workers=True)
    loaders = self.data.build()
    brain_module = self._build_brain_module(loaders["train"])
    trainer = self._setup_trainer()
    if not self.test_only:
        self.fit(brain_module, trainer, loaders["train"], loaders["val"])
    return self.test(brain_module, trainer, loaders["test"])
The @infra.apply decorator makes run() cacheable and
Slurm-submittable.
TaskInfra¶
exca.TaskInfra controls where and how the experiment runs:
import exca

# Execution backend: cluster=None runs in the current process; "auto"
# would submit to Slurm when available.
infra = exca.TaskInfra(
    cluster=None,
    folder="/tmp/results/mne_sample_clf",  # output directory + cache root
    gpus_per_node=1,
    cpus_per_task=10,
)
print(infra)
folder='/tmp/results/mne_sample_clf' cluster=None logs='{folder}/logs/{user}/%j' job_name=None timeout_min=None nodes=1 tasks_per_node=1 cpus_per_task=10 gpus_per_node=1 mem_gb=None max_pickle_size_gb=None slurm_constraint=None slurm_partition=None slurm_account=None slurm_qos=None slurm_use_srun=False slurm_additional_parameters=None conda_env=None workdir=None permissions=511 version='0' mode='cached' keep_in_ram=False
Key settings:
cluster=None – run locally; cluster="auto" – submit to Slurm if available; folder – output directory and cache root.
Hyperparameter sweeps¶
run_grid() expands a parameter grid into experiment configs:
from neuraltrain.utils import run_grid

# Fix: the grid must be defined *before* it is passed to run_grid (the
# original snippet referenced `grid` two statements ahead of its
# definition, which would raise NameError).  Each key is a dotted path
# into the Experiment config; each value is the list of settings to sweep.
grid = {
    "brain_model_config.hidden": [32, 64],
    "n_epochs": [20, 50],
    "seed": [33, 87],
}

# NOTE(review): `default_config` must be defined earlier (the baseline
# Experiment settings) -- confirm against the full project example.
results = run_grid(
    Experiment, "hp_search", default_config,
    grid, combinatorial=True,
)

# With combinatorial=True the grid expands to the full cross-product.
n_combinations = 1
for key, values in grid.items():
    n_combinations *= len(values)
    print(f" {key}: {values}")
print(f"\nTotal combinations: {n_combinations}")
brain_model_config.hidden: [32, 64]
n_epochs: [20, 50]
seed: [33, 87]
Total combinations: 8
Callbacks and loggers¶
The experiment’s _setup_trainer() wires Lightning callbacks and
loggers:
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from neuraltrain import utils

callbacks = [
    # Stop when val_loss has not improved for 5 consecutive validations.
    EarlyStopping(monitor="val_loss", mode="min", patience=5),
    # Record the learning-rate schedule once per epoch.
    LearningRateMonitor(logging_interval="epoch"),
    # Keep the best (lowest val_loss) checkpoint plus the last one.
    ModelCheckpoint(
        save_last=True,
        save_top_k=1,
        dirpath="/tmp/checkpoints",
        filename="best",
        monitor="val_loss",
    ),
]
for cb in callbacks:
    print(type(cb).__name__)

# Logger configs: CSV always-local; W&B with model upload disabled.
csv_config = utils.CsvLoggerConfig(name="mne_sample_clf")
wandb_config = utils.WandbLoggerConfig(
    log_model=False, group="mne_sample_clf", project="mne_sample_clf"
)
print("CSV: ", csv_config)
print("W&B: ", wandb_config)
EarlyStopping
LearningRateMonitor
ModelCheckpoint
CSV: name='mne_sample_clf' version=None prefix='' flush_logs_every_n_steps=100
W&B: name=None group='mne_sample_clf' entity=None project='mne_sample_clf' offline=False host=None id=None dir=None anonymous=None log_model=False experiment=None prefix='' resume='allow'
Total running time of the script: (0 minutes 0.111 seconds)