Model Configs¶
neuraltrain represents models as pydantic configs. The config is what gets
stored in your experiment; the actual torch.nn.Module is built later, once
the input and output dimensions are known.
This tutorial covers:

- the config-first model pattern
- instantiating and building a model from config
- available model families and shared building blocks
- sweeping over model hyperparameters
Config first, module later¶
The central abstraction is BaseModelConfig. Each subclass is a typed
pydantic model that knows how to build() its nn.Module when runtime
dimensions are available.
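The shape of this pattern can be sketched with plain dataclasses. This is illustrative only: neuraltrain's real configs are pydantic models, and build() returns a torch.nn.Module, whereas the toy config below just returns the resolved layer widths.

```python
from dataclasses import dataclass

# Illustrative stand-in for BaseModelConfig: a typed, serialisable config
# whose build() produces the concrete model once dimensions are known.
@dataclass
class ToyConvConfig:
    hidden: int = 32
    depth: int = 4

    def build(self, in_channels: int, out_channels: int) -> list[int]:
        # In the real library this would assemble an nn.Module; here we
        # just return the per-layer channel widths to show the idea.
        return [in_channels] + [self.hidden] * (self.depth - 1) + [out_channels]

config = ToyConvConfig(hidden=32, depth=4)
print(config.build(in_channels=204, out_channels=4))
# [204, 32, 32, 32, 4]
```

The point is that the config carries only hyperparameters and can be stored with the experiment, while the data-dependent dimensions are supplied at build time.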
In the example project, the model is built inside
Experiment._build_brain_module(), once the first batch reveals the
input shape.
Let’s instantiate the default model config (SimpleConvTimeAgg, a
1-D convolutional encoder with temporal aggregation):
from neuraltrain import BaseModelConfig, models
model_config = models.SimpleConvTimeAgg(hidden=32, depth=4, merger_config=None)
print(model_config)
hidden=32 depth=4 linear_out=False complex_out=False kernel_size=5 growth=1.0 dilation_growth=2 dilation_period=None skip=False post_skip=False scale=None rewrite=False groups=1 glu=0 glu_context=0 glu_glu=True gelu=False dropout=0.0 dropout_rescale=True conv_dropout=0.0 dropout_input=0.0 batch_norm=False relu_leakiness=0.0 transformer_config=None subject_layers_config=None subject_layers_dim='hidden' merger_config=None initial_linear=0 initial_depth=1 initial_nonlin=False backbone_out_channels=None time_agg_out='gap' n_time_groups=None output_head_config=None
Building the model¶
Calling build() with the runtime dimensions produces a concrete
nn.Module that we can inspect. In this example the first batch supplies
204 input channels, which appear in the first Conv1d below:
SimpleConvTimeAggModel(
(encoder): ConvSequence(
(sequence): ModuleList(
(0): Sequential(
(0): Conv1d(204, 32, kernel_size=(5,), stride=(1,), padding=(2,))
(1): ReLU()
)
(1): Sequential(
(0): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(4,), dilation=(2,))
(1): ReLU()
)
(2): Sequential(
(0): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(8,), dilation=(4,))
(1): ReLU()
)
(3): Sequential(
(0): Conv1d(32, 4, kernel_size=(5,), stride=(1,), padding=(16,), dilation=(8,))
)
)
(glus): ModuleList(
(0-3): 4 x None
)
)
(time_agg_out): AdaptiveAvgPool1d(output_size=1)
)
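The dilation and padding values in the printout follow a simple pattern: with dilation_growth=2 and kernel_size=5, layer i uses dilation 2**i and "same"-style padding of dilation * (kernel_size - 1) // 2. This is a plain-Python check that mirrors the four layers above (the exact formula inside neuraltrain is not shown in this tutorial):

```python
kernel_size = 5
dilation_growth = 2
depth = 4

# Reproduce the dilation/padding pairs printed in the module above.
for i in range(depth):
    dilation = dilation_growth ** i               # 1, 2, 4, 8
    padding = dilation * (kernel_size - 1) // 2   # 2, 4, 8, 16
    print(f"layer {i}: dilation={dilation}, padding={padding}")
```

With these paddings each convolution preserves the temporal length, so the receptive field grows exponentially with depth without shrinking the sequence.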
Sweeping model configs¶
Because models are typed configs, you can sweep architecture parameters in a grid alongside everything else:
grid = {
"brain_model_config.hidden": [32, 64],
"brain_model_config.depth": [2, 4],
}
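A grid like this expands to the Cartesian product of its values, one run per combination (four here). A small sketch of that expansion; how neuraltrain itself schedules the resulting runs is not shown in this tutorial:

```python
from itertools import product

grid = {
    "brain_model_config.hidden": [32, 64],
    "brain_model_config.depth": [2, 4],
}

# Cartesian product of the value lists: one override dict per run.
runs = [dict(zip(grid, values)) for values in product(*grid.values())]
for overrides in runs:
    print(overrides)
# 4 runs: (32, 2), (32, 4), (64, 2), (64, 4)
```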
This is one of the main advantages of keeping the model choice in a serialisable configuration object.
Total running time of the script: (0 minutes 0.110 seconds)