# Source code for neuraltrain.models.base
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Pydantic configurations for models."""
import typing as tp
import pydantic
from exca import helpers
from torch import nn
[docs]
class BaseModelConfig(helpers.DiscriminatedModel, discriminator_key="name"):
    """Common root of all model configuration classes.

    Built on ``helpers.DiscriminatedModel`` with ``"name"`` as the
    discriminator key. Subclasses provide the actual :meth:`build` logic.
    """

    def build(self, *args, **kwargs) -> nn.Module:
        """Instantiate the ``nn.Module`` described by this configuration.

        The base implementation is a placeholder that accepts any arguments
        and must be overridden.

        Raises
        ------
        NotImplementedError
            Always, when called on a class that does not override ``build``.
        """
        raise NotImplementedError
# Base class for braindecode model configs (using kwargs pattern)
[docs]
class BaseBrainDecodeModel(BaseModelConfig):
    """Configuration base for models coming from braindecode.

    ``_MODEL_CLASS`` is only declared here (no value); it has to be assigned
    on a subclass before that subclass is instantiated, because both
    validation and :meth:`build` read it.

    Attributes
    ----------
    kwargs : dict
        Free-form keyword arguments forwarded to the braindecode model
        constructor. They are checked against ``_MODEL_CLASS`` with
        ``helpers.validate_kwargs`` as soon as the config is created.
    from_pretrained_name : str or None
        Optional HuggingFace Hub repository ID (e.g.
        ``"braindecode/labram-pretrained"``). When set, ``build()`` goes
        through ``_MODEL_CLASS.from_pretrained()`` rather than calling the
        constructor directly.
    """

    _MODEL_CLASS: tp.ClassVar[tp.Any]
    kwargs: dict[str, tp.Any] = {}
    from_pretrained_name: str | None = None

    def model_post_init(self, __context__: tp.Any) -> None:
        """Run the parent hook, then validate ``kwargs`` for ``_MODEL_CLASS``."""
        super().model_post_init(__context__)
        helpers.validate_kwargs(self._MODEL_CLASS, self.kwargs)

    def build(self, **kwargs: tp.Any) -> nn.Module:
        """Build the model from the stored kwargs plus call-time kwargs.

        Parameters
        ----------
        **kwargs
            Extra keyword arguments supplied at build time. They are merged
            with ``self.kwargs`` and must not reuse any of its keys.

        Returns
        -------
        nn.Module
            ``_MODEL_CLASS(**merged)`` when ``from_pretrained_name`` is
            ``None``; otherwise
            ``_MODEL_CLASS.from_pretrained(from_pretrained_name, **merged)``.

        Raises
        ------
        ValueError
            If a key is present both in ``self.kwargs`` and in ``kwargs``.
        """
        overlap = set(self.kwargs) & set(kwargs)
        if overlap:
            # Refuse silent overrides: a key may come from the config or from
            # the call site, never from both.
            raise ValueError(
                f"Build kwargs overlap with config kwargs for keys: {overlap}."
            )
        merged = {**self.kwargs, **kwargs}
        model_cls = self._MODEL_CLASS
        if self.from_pretrained_name is None:
            return model_cls(**merged)  # type: ignore
        return model_cls.from_pretrained(self.from_pretrained_name, **merged)
def _register_braindecode_models() -> None:
    """Create and publish one config class per braindecode model.

    For every name listed in ``braindecode.models.__all__``, a pydantic model
    deriving from ``BaseBrainDecodeModel`` is created under that same name,
    its ``_MODEL_CLASS`` is set to the attribute of ``braindecode.models``
    with that name, and the new class is stored in this module's globals.

    braindecode is imported inside the function, so a missing installation
    surfaces as an ``ImportError`` raised from this call.
    """
    import braindecode.models
    from braindecode.models import __all__ as exported_names

    namespace = globals()
    for model_name in exported_names:
        config_cls: type[BaseBrainDecodeModel] = pydantic.create_model(  # type: ignore[assignment]
            model_name,
            __base__=BaseBrainDecodeModel,
        )
        # Bind the generated config to the braindecode class it will build.
        config_cls._MODEL_CLASS = getattr(braindecode.models, model_name)  # type: ignore[attr-defined]
        namespace[model_name] = config_cls
try:
    _register_braindecode_models()
except ImportError:
    # Best-effort registration: if the import fails, the per-model config
    # classes are simply not added to this module.
    pass