Metrics and Optimisation
Losses, metrics, and optimisers are stored as typed config objects, just like models. That keeps the whole training recipe serialisable and sweepable.
This tutorial covers:

- configuring losses, metrics, and optimisers
- building runtime objects from configs
- scheduler integration
Loss config
BaseLoss is a discriminated pydantic model. All standard PyTorch
losses are auto-registered, plus custom losses like ClipLoss and
SigLipLoss.
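A minimal example, assuming BaseLoss follows the same name/kwargs/build() pattern as BaseMetric below:

```python
from neuraltrain import BaseLoss

# select the registered torch.nn.CrossEntropyLoss by name and build it
loss_config = BaseLoss(name="CrossEntropyLoss")
loss = loss_config.build()
print("Built:", loss)
```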
Built: CrossEntropyLoss()
Metric config
BaseMetric is a discriminated pydantic model. All standard torchmetrics
are auto-registered, plus custom metrics like Rank and
ImageSimilarity.
from neuraltrain import BaseMetric

metric_config = BaseMetric(
    name="Accuracy",
    log_name="acc",
    kwargs={"task": "multiclass", "num_classes": 4},
)
metric = metric_config.build()
print("Built:", metric)
Built: MulticlassAccuracy()
Optimiser config
LightningOptimizer bundles a torch optimiser with an optional
scheduler. All torch.optim optimisers and
torch.optim.lr_scheduler schedulers are auto-registered:
from neuraltrain import LightningOptimizer

optim_config = LightningOptimizer(
    optimizer={"name": "Adam", "lr": 1e-4, "kwargs": {"weight_decay": 0.0}},
    scheduler={"name": "OneCycleLR", "kwargs": {"max_lr": 3e-3, "pct_start": 0.2}},
)
print(optim_config)
optimizer=Adam(lr=0.0001, kwargs={'weight_decay': 0.0}) scheduler=OneCycleLR(kwargs={'max_lr': 0.003, 'pct_start': 0.2}) interval='step'
build() is called inside BrainModule.configure_optimizers()
with the model parameters and scheduler kwargs:
def configure_optimizers(self):
    try:
        # schedulers like OneCycleLR need total_steps, which is only
        # known once the trainer has been attached
        return self.optim_config.build(
            self.parameters(),
            total_steps=self.trainer.estimated_stepping_batches,
        )
    except TypeError:
        # the configured scheduler (or lack of one) takes no total_steps
        return self.optim_config.build(self.parameters())
All these objects serialise cleanly to JSON for experiment snapshots, and can be swept just like model configs:
grid = {
    "loss.name": ["CrossEntropyLoss", "ClipLoss"],
    "optim.optimizer.lr": [1e-3, 1e-4],
}
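Expanding such a grid into per-run override dictionaries takes only the standard library (a sketch; neuraltrain's own sweep driver may differ):

```python
from itertools import product

grid = {
    "loss.name": ["CrossEntropyLoss", "ClipLoss"],
    "optim.optimizer.lr": [1e-3, 1e-4],
}

# Cartesian product of the per-key value lists -> one dict per run
keys = list(grid)
runs = [dict(zip(keys, combo)) for combo in product(*grid.values())]
print(len(runs))  # 4 runs: every loss paired with every learning rate
```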