Source code for neuraltrain.models.linear

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

from torch import nn

from .base import BaseModelConfig
from .common import SubjectLayers


[docs] class Linear(BaseModelConfig): """Simple linear projection, with optional per-subject weights. Parameters ---------- reduction : {"mean", "concat"} How to reduce the time dimension before the linear layer. ``"mean"`` averages over time; ``"concat"`` flattens channels and time. subject_layers_config : SubjectLayers or None If set, use a :class:`SubjectLayersModel` instead of a shared ``nn.Linear``. """ reduction: tp.Literal["mean", "concat"] = "mean" subject_layers_config: SubjectLayers | None = None def build(self, n_in_channels: int, n_outputs: int) -> nn.Module: return LinearModel( n_in_channels, n_outputs, reduction=self.reduction, subject_layers_config=self.subject_layers_config, )
[docs] class LinearModel(nn.Module): """``nn.Module`` implementation of :class:`Linear`.""" def __init__( self, n_in_channels: int, n_outputs: int, reduction: str = "mean", subject_layers_config: SubjectLayers | None = None, ): super().__init__() self.n_in_channels = n_in_channels self.n_outputs = n_outputs self.linear: tp.Any if subject_layers_config is not None: self.linear = subject_layers_config.build(n_in_channels, n_outputs) else: self.linear = nn.Linear(n_in_channels, n_outputs) self.reduction = reduction
[docs] def forward(self, x, subject_id=None): """Forward pass: reduce time, then project. Parameters ---------- x : Tensor Input of shape ``(B, C, T)`` or ``(B, C)``. subject_id : Tensor or None Per-example subject indices, shape ``(B,)``. """ if len(x.shape) > 2: if self.reduction == "concat": x = x.view(x.size(0), -1) elif self.reduction == "mean": x = x.mean(dim=-1) if isinstance(self.linear, nn.Linear): x = self.linear(x) else: x = self.linear(x.unsqueeze(-1), subject_id).squeeze(-1) return x