# Source code for neuraltrain.models.reve

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Neuraltrain custom configuration for REVE.

Includes the following adaptations over the raw braindecode REVE model:

* **Channel name remapping** -- an explicit ``channel_mapping`` dict maps
  dataset channel names to names recognised by REVE's position bank.
  Channels whose names already appear in the bank need no entry.
* **Pretrained loading** -- REVE's ``__init__`` requires ``n_times`` and
  ``n_outputs`` to build its ``final_layer``, but
  :class:`BaseBrainDecodeModel.build` blocks ``n_times`` for pretrained
  models.  ``NtReve.build()`` calls ``from_pretrained`` directly.
* **Encoder-only support** -- when ``n_outputs`` is not provided (i.e. when
  a ``DownstreamWrapperModel`` handles the classification head), the wrapper
  calls ``forward(return_output=True)`` and extracts the final transformer
  layer's output, bypassing REVE's ``final_layer`` entirely.
"""

import logging
import typing as tp

import torch
import torch.nn as nn

from .base import BaseBrainDecodeModel
from .common import parse_bipolar_name

logger = logging.getLogger(__name__)


class _ReveWrapper(nn.Module):
    """Thin wrapper around REVE for channel subsetting and encoder-only output.

    Combines two optional adaptations in a single module:

    * **Channel subsetting** -- when *channel_indices* is not ``None``,
      the input EEG tensor is sliced along the channel dimension before
      forwarding to REVE.  No-op when ``None``.
    * **Encoder-only output** -- when *encoder_only* is ``True``, the
      forward call passes ``return_output=True`` to REVE and returns the
      final transformer layer output (index ``-1``), bypassing REVE's
      ``final_layer``.  When ``False``, forwards normally.
    """

    channel_indices: torch.Tensor | None

    def __init__(
        self,
        model: nn.Module,
        channel_indices: list[int] | None = None,
        encoder_only: bool = False,
    ):
        super().__init__()
        self.model = model
        self.encoder_only = encoder_only
        if channel_indices is not None:
            self.register_buffer(
                "channel_indices",
                torch.tensor(channel_indices, dtype=torch.long),
            )
        else:
            self.register_buffer("channel_indices", None)

    def forward(
        self, eeg: torch.Tensor, pos: torch.Tensor | None = None, **kwargs: tp.Any
    ) -> torch.Tensor:
        if self.channel_indices is not None:
            eeg = eeg[:, self.channel_indices]
        if self.encoder_only:
            return self.model(eeg, pos=pos, return_output=True, **kwargs)[-1]
        return self.model(eeg, pos=pos, **kwargs)


class NtReve(BaseBrainDecodeModel):
    """Config for the braindecode REVE model with channel-mapping support.

    Extends :class:`BaseBrainDecodeModel` with REVE-specific logic:

    1. **Channel remapping** -- an explicit ``channel_mapping`` dict maps
       dataset channel names to REVE position-bank names.  Channels whose
       names already appear in the bank (exact match) need no entry.
    2. **Pretrained loading** -- bypasses the base-class restriction on
       ``n_times`` for pretrained models, since REVE needs it to size its
       ``final_layer``.
    3. **Encoder-only output** -- when ``n_outputs`` is ``None`` (downstream
       wrapper handles the head), the model is wrapped to call
       ``forward(return_output=True)`` and return the final transformer
       layer output, bypassing REVE's ``final_layer``.

    Parameters
    ----------
    channel_mapping : dict or None
        Explicit mapping from dataset channel names to REVE position-bank
        names.  Useful for EEG systems whose naming convention is absent
        from the bank (e.g. Neuromag ``"EEG 005"`` or easycap-M10 numeric
        ``"2"``).
    """

    _MODEL_CLASS: tp.ClassVar[tp.Any] = None  # resolved lazily; see _ensure_model_class
    chs_info_required: tp.ClassVar[bool] = True
    needs_n_times: tp.ClassVar[bool] = True

    channel_mapping: dict[str, str] | None = None

    @classmethod
    def _ensure_model_class(cls) -> None:
        """Resolve ``_MODEL_CLASS`` on first use.

        Must also be called in ``build()`` because ``model_post_init`` is
        not invoked after submitit deserialization on SLURM workers.
        """
        if cls._MODEL_CLASS is None:
            # Deferred import keeps braindecode out of module import time.
            import braindecode.models

            cls._MODEL_CLASS = braindecode.models.REVE

    def model_post_init(self, __context__: tp.Any) -> None:
        # NOTE(review): pydantic-style post-init hook — resolve the model
        # class eagerly so validation of this config can use it.
        type(self)._ensure_model_class()
        super().model_post_init(__context__)

    def _remap_chs_info(
        self,
        chs_info: list[dict[str, tp.Any]],
    ) -> list[dict[str, tp.Any]]:
        """Apply ``channel_mapping`` to *chs_info*, returning a new list."""
        if not self.channel_mapping:
            return chs_info
        # Unmapped names fall through unchanged (dict.get default).
        return [
            {**ch, "ch_name": self.channel_mapping.get(ch["ch_name"], ch["ch_name"])}
            for ch in chs_info
        ]

    @staticmethod
    def _derive_bipolar_position(
        name: str,
        bank: tp.Any,
    ) -> torch.Tensor | None:
        """Look up the anode position for a bipolar channel name.

        Returns ``None`` when *name* is not a valid bipolar pair or the
        anode electrode is missing from *bank*.
        """
        pair = parse_bipolar_name(name)
        if pair is None:
            return None
        # Use the anode (first electrode of the pair) as the proxy position.
        anode = pair[0]
        if anode not in bank.mapping:
            return None
        return bank.embedding[bank.mapping[anode]]

    def build(
        self,
        n_chans: int | None = None,
        n_times: int | None = None,
        n_outputs: int | None = None,
        chs_info: list[dict[str, tp.Any]] | None = None,
        **kwargs: tp.Any,
    ) -> nn.Module:
        """Build the (optionally wrapped) REVE model.

        Resolves each dataset channel against REVE's position bank
        (directly, via ``channel_mapping``, or via the bipolar-anode
        fallback), drops unresolvable channels, and wraps the built model
        in :class:`_ReveWrapper` for channel subsetting and optional
        encoder-only output.
        """
        # Re-resolve here: model_post_init may not have run on SLURM
        # workers after submitit deserialization (see _ensure_model_class).
        type(self)._ensure_model_class()
        if chs_info is not None:
            chs_info = self._remap_chs_info(chs_info)
        channel_indices: list[int] | None = None
        custom_positions: torch.Tensor | None = None
        if chs_info is not None:
            from braindecode.models.reve import RevePositionBank

            bank = RevePositionBank()
            n_original = len(chs_info)
            valid_indices: list[int] = []
            valid_chs: list[dict[str, tp.Any]] = []
            positions: list[torch.Tensor] = []
            dropped: list[str] = []
            n_derived = 0
            # Resolve each channel: exact bank match first, then the
            # bipolar anode fallback; otherwise drop it.
            for i, ch in enumerate(chs_info):
                name = ch["ch_name"]
                if name in bank.mapping:
                    valid_indices.append(i)
                    valid_chs.append(ch)
                    positions.append(bank.embedding[bank.mapping[name]])
                else:
                    derived = self._derive_bipolar_position(name, bank)
                    if derived is not None:
                        valid_indices.append(i)
                        valid_chs.append(ch)
                        positions.append(derived)
                        n_derived += 1
                    else:
                        dropped.append(name)
            if dropped:
                logger.warning(
                    "Dropping %d channel(s) not resolvable from REVE position bank: %s",
                    len(dropped),
                    dropped,
                )
            if n_derived:
                logger.info(
                    "Mapped %d bipolar channel(s) to REVE via anode fallback.",
                    n_derived,
                )
            if len(valid_chs) < n_original:
                # Some channels were dropped: remember which input indices
                # survive so _ReveWrapper can slice the EEG tensor.
                channel_indices = valid_indices
                chs_info = valid_chs
                n_chans = len(chs_info)
            if not valid_chs:
                raise ValueError(
                    "No dataset channels match the REVE position bank "
                    "(directly or via anode fallback). "
                    "Consider adding a `channel_mapping`."
                )
            logger.info(
                "[REVE_CHANNELS] n_dataset=%d n_resolved=%d (n_derived=%d)",
                n_original,
                len(valid_chs),
                n_derived,
            )
            if n_derived > 0:
                # Derived (anode-fallback) positions cannot come from
                # chs_info, so they are supplied directly below.
                custom_positions = torch.stack(positions)
        build_kwargs: dict[str, tp.Any] = {}
        if n_chans is not None:
            build_kwargs["n_chans"] = n_chans
        # When positions were derived manually, skip chs_info so REVE does
        # not re-derive positions itself; default_pos is set instead.
        if chs_info is not None and custom_positions is None:
            build_kwargs["chs_info"] = chs_info
        if n_times is not None:
            build_kwargs["n_times"] = n_times
        # No n_outputs means a downstream wrapper owns the head -> wrap
        # the model for encoder-only output.
        encoder_only = n_outputs is None
        if self.from_pretrained_name is not None:
            # Pretrained checkpoints need *some* final_layer to load;
            # 2 is a placeholder when the head is bypassed anyway.
            build_kwargs["n_outputs"] = n_outputs if n_outputs is not None else 2
        elif n_outputs is not None:
            build_kwargs["n_outputs"] = n_outputs
        model = super().build(**build_kwargs, **kwargs)
        if custom_positions is not None:
            # NOTE(review): assumes REVE reads per-channel positions from
            # the `default_pos` attribute when `pos` is not passed at
            # forward time — confirm against braindecode's REVE internals.
            model.default_pos = custom_positions
        return _ReveWrapper(model, channel_indices, encoder_only)