neuraltrain.models.fmri_mlp.FmriLinearModel

class neuraltrain.models.fmri_mlp.FmriLinearModel(in_dim: int, out_dim: int, config: FmriLinear | None = None)

nn.Module implementation of FmriLinear.

forward(x: Tensor, subject_ids: Tensor | None = None) → Tensor | dict[str, Tensor]

Forward pass through the linear model.

Parameters:
  • x (Tensor) – Input of shape (B, F, T).

  • subject_ids (Tensor or None) – Per-example subject indices (unused, kept for API consistency).
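The page does not say how the linear map is applied to a (B, F, T) input. Below is a minimal, hypothetical sketch of a module with the same call shape. It assumes that `in_dim` is the feature axis F and that one shared `nn.Linear` is applied independently at each time step. The class name `HypotheticalFmriLinear` is invented for illustration. This is not the neuraltrain implementation, and it does not reproduce the `dict[str, Tensor]` return path or the `config` argument.

```python
from __future__ import annotations

import torch
from torch import nn


class HypotheticalFmriLinear(nn.Module):
    """Hypothetical sketch: a linear map over the feature axis of a (B, F, T) input."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        # Assumption: in_dim is F, the size of the feature axis.
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor, subject_ids: torch.Tensor | None = None) -> torch.Tensor:
        # subject_ids is accepted but ignored, mirroring the documented parameter.
        del subject_ids
        # (B, F, T) -> (B, T, F) so nn.Linear acts on the feature axis at every time step.
        out = self.linear(x.transpose(1, 2))
        # (B, T, out_dim) -> (B, out_dim, T)
        return out.transpose(1, 2)


model = HypotheticalFmriLinear(in_dim=16, out_dim=4)
x = torch.randn(2, 16, 10)  # B=2 examples, F=16 features, T=10 time points
y = model(x)
print(y.shape)  # torch.Size([2, 4, 10])
```

Under this assumption, passing `subject_ids` leaves the output unchanged, which is consistent with the parameter being kept only for API consistency.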