Source code for neuraltrain.models.base
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Pydantic configurations for models."""
import importlib.util
import typing as tp
import warnings

import pydantic
from exca import helpers
from torch import nn
[docs]
class BaseModelConfig(helpers.DiscriminatedModel, discriminator_key="name"):
    """Root class shared by every model configuration.

    Derives from exca's ``DiscriminatedModel`` with ``"name"`` as the
    discriminator key. Concrete configurations override :meth:`build`.
    """

    def build(self, *args, **kwargs) -> nn.Module:
        """Create the ``nn.Module`` described by this configuration.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses supply the implementation.
        """
        raise NotImplementedError
# Base class for braindecode model configs (using kwargs pattern)
[docs]
class BaseBrainDecodeModel(BaseModelConfig):
    """Base class for braindecode model configurations.

    Subclasses set ``_MODEL_CLASS_PATH`` (e.g.
    ``"braindecode.models.Labram"``) to resolve the underlying class lazily,
    avoiding an unconditional braindecode import at module load time.
    Subclasses that need custom resolution (e.g. optional-dependency
    handling) can instead override ``_ensure_model_class`` directly.

    The dynamic registration in :func:`_register_braindecode_models` sets
    ``_MODEL_CLASS`` directly at import time for the common braindecode
    models, which short-circuits the lazy path.

    A subclass that declares its own ``_MODEL_CLASS_PATH`` always resolves
    that path, even if a parent class has already resolved a different
    ``_MODEL_CLASS`` (which the subclass would otherwise inherit).

    Attributes
    ----------
    kwargs : dict
        Free-form keyword arguments forwarded to the braindecode model
        constructor. Validated against the model's ``__init__`` signature at
        config creation time.
    from_pretrained_name : str or None
        Optional HuggingFace Hub repository ID (e.g.
        ``"braindecode/labram-pretrained"``). When set, ``build()`` calls
        ``_MODEL_CLASS.from_pretrained()`` instead of the regular constructor.
    """

    # Resolved model class; set by _ensure_model_class or directly by
    # _register_braindecode_models.
    _MODEL_CLASS: tp.ClassVar[tp.Any] = None
    # Dotted import path ("package.module.Class") used for lazy resolution.
    _MODEL_CLASS_PATH: tp.ClassVar[str | None] = None
    kwargs: dict[str, tp.Any] = {}
    from_pretrained_name: str | None = None

    @classmethod
    def _ensure_model_class(cls) -> None:
        """Resolve ``_MODEL_CLASS`` on first use.

        Called from both ``model_post_init`` and ``build`` because
        submitit deserialization on SLURM workers does not invoke
        ``model_post_init``.

        Raises
        ------
        RuntimeError
            If neither ``_MODEL_CLASS`` nor ``_MODEL_CLASS_PATH`` is set.
        ValueError
            If ``_MODEL_CLASS_PATH`` is not a dotted ``"module.attr"`` path.
        """
        own_attrs = vars(cls)
        own_path = own_attrs.get("_MODEL_CLASS_PATH")
        if cls._MODEL_CLASS is not None:
            # ``_MODEL_CLASS`` may come from a parent via the MRO. Reuse it only
            # when it was set on this class itself, or when this class does not
            # redirect resolution through its own ``_MODEL_CLASS_PATH``.
            if "_MODEL_CLASS" in own_attrs or own_path is None:
                return
        path = cls._MODEL_CLASS_PATH
        if path is None:
            raise RuntimeError(
                f"{cls.__name__} has neither `_MODEL_CLASS` nor `_MODEL_CLASS_PATH` set."
            )
        # Assigning on ``cls`` caches the result on this class only.
        cls._MODEL_CLASS = cls._import_model_class(path)

    @staticmethod
    def _import_model_class(path: str) -> tp.Any:
        """Import the module part of dotted ``path`` and return its last attribute."""
        import importlib

        module_name, _, attr = path.rpartition(".")
        if not module_name or not attr:
            raise ValueError(
                "`_MODEL_CLASS_PATH` must be a dotted path such as "
                f"'braindecode.models.Labram', got {path!r}."
            )
        return getattr(importlib.import_module(module_name), attr)

    def model_post_init(self, __context__: tp.Any) -> None:
        """Resolve the model class, then validate ``kwargs`` against it."""
        type(self)._ensure_model_class()
        super().model_post_init(__context__)
        helpers.validate_kwargs(self._MODEL_CLASS, self.kwargs)

    def build(self, **kwargs: tp.Any) -> nn.Module:
        """Instantiate the braindecode model.

        Parameters
        ----------
        **kwargs
            Build-time constructor arguments, merged with ``self.kwargs``.
            They must not repeat any key already present in ``self.kwargs``.

        Returns
        -------
        nn.Module
            ``_MODEL_CLASS.from_pretrained(from_pretrained_name, **merged)``
            when ``from_pretrained_name`` is set, else
            ``_MODEL_CLASS(**merged)``.

        Raises
        ------
        ValueError
            If a build-time keyword also appears in ``self.kwargs``.
        """
        # Re-resolve here: model_post_init may not have run (submitit).
        type(self)._ensure_model_class()
        if overlap := set(self.kwargs) & set(kwargs):
            raise ValueError(
                f"Build kwargs overlap with config kwargs for keys: {overlap}."
            )
        kwargs = self.kwargs | kwargs
        if self.from_pretrained_name is not None:
            return self._MODEL_CLASS.from_pretrained(self.from_pretrained_name, **kwargs)
        return self._MODEL_CLASS(**kwargs)  # type: ignore
def _register_braindecode_models() -> None:
    """Create and publish one config subclass per braindecode model.

    For each name listed in ``braindecode.models.__all__``, a pydantic
    subclass of :class:`BaseBrainDecodeModel` with that same name is created.
    Its ``_MODEL_CLASS`` is bound to the matching ``braindecode.models``
    attribute, and the subclass is stored in this module's globals.

    Raises
    ------
    ImportError
        If ``braindecode.models`` (or its ``__all__``) cannot be imported.
        The module-level caller handles this, keeping braindecode optional.
    """
    import braindecode.models
    from braindecode.models import __all__ as model_names

    module_namespace = globals()
    for model_name in model_names:
        config_cls: type[BaseBrainDecodeModel] = pydantic.create_model(  # type: ignore[assignment]
            model_name,
            __base__=BaseBrainDecodeModel,
        )
        # Binding the class eagerly skips the lazy _MODEL_CLASS_PATH lookup.
        config_cls._MODEL_CLASS = getattr(braindecode.models, model_name)  # type: ignore[attr-defined]
        module_namespace[model_name] = config_cls
# braindecode is an optional dependency. Register its model configs when it
# imports cleanly, and stay silent when it is simply not installed. If it IS
# installed but fails to import, warn instead of hiding the broken environment.
try:
    _register_braindecode_models()
except ImportError as exc:
    if importlib.util.find_spec("braindecode") is not None:
        warnings.warn(
            f"braindecode is installed but its models could not be registered: {exc}",
            stacklevel=2,
        )