Source code for neuraltrain.models.luna
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Neuraltrain custom configuration for LUNA.
Includes the following adaptations over the raw braindecode LUNA model:
* **Keyword mapping** — the data pipeline provides per-sample 3-D
electrode positions as ``channel_positions``, whereas braindecode's LUNA
forward signature expects ``channel_locations``. The wrapper transparently
maps between the two.
* **Time padding** — zero-pads the time dimension so it is a multiple of
LUNA's ``patch_size``.
* **Encoder-only output** — when ``n_outputs`` is not provided (i.e. the
downstream classification head is handled externally), the classification
head is replaced with ``nn.Identity()`` so the model returns the normalised
encoder latent of shape ``(B, n_patches, num_queries * embed_dim)``.
"""
import logging
import typing as tp
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import BaseBrainDecodeModel
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class _LunaEncoderWrapper(nn.Module):
"""Wraps a braindecode ``LUNA`` to accept ``channel_positions`` instead of
``channel_locations`` and to zero-pad the time dimension.
"""
def __init__(self, model: nn.Module) -> None:
super().__init__()
self.model = model
@property
def mapping(self) -> dict[str, str]:
"""Expose LUNA's key mapping, prefixed for the wrapper's state dict."""
inner_mapping = getattr(self.model, "mapping", None) or {}
return {k: f"model.{v}" for k, v in inner_mapping.items()}
@staticmethod
def _pad_to_patch_size(X: torch.Tensor, patch_size: int) -> torch.Tensor:
"""Zero-pad the time dimension to the next multiple of *patch_size*."""
remainder = X.shape[-1] % patch_size
if remainder == 0:
return X
return F.pad(X, (0, patch_size - remainder))
def forward(
self,
X: torch.Tensor,
channel_positions: torch.Tensor | None = None,
) -> torch.Tensor:
X = self._pad_to_patch_size(X, self.model.patch_size) # type: ignore[union-attr,arg-type]
return self.model(X, channel_locations=channel_positions)
[docs]
class NtLuna(BaseBrainDecodeModel):
    """Config for the braindecode LUNA model.

    Extends :class:`BaseBrainDecodeModel` with LUNA-specific logic:

    1. **Keyword mapping** — wraps the model so its forward accepts
       ``channel_positions`` and maps it to LUNA's ``channel_locations``.
    2. **Time padding** — zero-pads the time dimension to a multiple of
       ``patch_size``.
    3. **Encoder-only output** — when ``n_outputs`` is not passed (i.e. when
       a ``DownstreamWrapperModel`` handles the classification head), the
       classification head is replaced with ``nn.Identity()`` so the model
       returns the encoder latent.

    Parameters
    ----------
    pretrained_filename : str or None
        When ``from_pretrained_name`` points to a Hub repository containing
        multiple weight files (e.g. ``PulpBio/LUNA``), this selects which
        file to download. Requires braindecode >= 1.5 which natively
        supports the ``filename`` kwarg in ``from_pretrained``. It has no
        effect (and a warning is logged at build time) when
        ``from_pretrained_name`` is ``None``.
    """

    # Resolved lazily (see ``_ensure_model_class``), which defers the
    # ``braindecode`` import until first use.
    _MODEL_CLASS: tp.ClassVar[tp.Any] = None
    pretrained_filename: str | None = None

    @classmethod
    def _ensure_model_class(cls) -> None:
        """Resolve ``_MODEL_CLASS`` on first use.

        Must also be called in ``build()`` because ``model_post_init`` is not
        invoked after submitit deserialization on SLURM workers.
        """
        if cls._MODEL_CLASS is None:
            import braindecode.models

            cls._MODEL_CLASS = braindecode.models.LUNA

    def model_post_init(self, __context__: tp.Any) -> None:
        """Resolve the model class, then run the parent's post-init hook."""
        type(self)._ensure_model_class()
        super().model_post_init(__context__)

    def build(
        self,
        n_chans: int | None = None,
        n_times: int | None = None,
        n_outputs: int | None = None,
        **kwargs: tp.Any,
    ) -> nn.Module:
        """Build the LUNA model and wrap it for the neuraltrain pipeline.

        Parameters
        ----------
        n_chans, n_times : int or None
            Forwarded to the parent ``build`` only when not ``None``.
        n_outputs : int or None
            Number of outputs of the classification head. When ``None`` the
            model is built in encoder-only mode: a placeholder single-output
            head is built and then replaced with ``nn.Identity()``.
        **kwargs
            Forwarded to the parent ``build``. ``filename`` is set from
            ``pretrained_filename`` when loading pretrained weights.

        Returns
        -------
        nn.Module
            A ``_LunaEncoderWrapper`` around the built model.

        Raises
        ------
        AttributeError
            In encoder-only mode, if the built model has no ``final_layer``
            to replace (assigning it anyway would only register an unused
            submodule and leave the real head in place).
        """
        # Re-resolve here: model_post_init is skipped after submitit
        # deserialization on SLURM workers.
        type(self)._ensure_model_class()
        build_kwargs: dict[str, tp.Any] = {}
        if n_chans is not None:
            build_kwargs["n_chans"] = n_chans
        if n_times is not None:
            build_kwargs["n_times"] = n_times
        encoder_only = n_outputs is None
        # Encoder-only: build with a placeholder single-output head; it is
        # swapped for nn.Identity() below.
        build_kwargs["n_outputs"] = 1 if encoder_only else n_outputs
        if self.pretrained_filename is not None:
            if self.from_pretrained_name is not None:
                kwargs["filename"] = self.pretrained_filename
            else:
                # Previously ignored silently; surface the misconfiguration.
                logger.warning(
                    "pretrained_filename=%r is ignored because "
                    "from_pretrained_name is not set.",
                    self.pretrained_filename,
                )
        model = super().build(**build_kwargs, **kwargs)
        if encoder_only:
            # nn.Module attribute assignment never fails, so check that the
            # head actually exists before replacing it.
            if not hasattr(model, "final_layer"):
                raise AttributeError(
                    f"{type(model).__name__} has no 'final_layer' attribute; "
                    "cannot replace the classification head for encoder-only output."
                )
            model.final_layer = nn.Identity()
        return _LunaEncoderWrapper(model)