neuraltrain.models.fmri_mlp.FmriMlpModel

class neuraltrain.models.fmri_mlp.FmriMlpModel(in_dim: int, out_dim: int, config: FmriMlp | None = None)

Residual MLP adapted from [1].

See https://github.com/MedARC-AI/fMRI-reconstruction-NSD/blob/main/src/models.py#L171
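For readers new to this kind of architecture, here is a minimal PyTorch sketch of a generic residual MLP. It is not neuraltrain's FmriMlpModel and does not read an FmriMlp config. The class name ResidualMlpSketch, the parameters hidden_dim, n_blocks and dropout, and the Linear, LayerNorm, GELU, Dropout block layout are illustrative assumptions. The sketch also expects a flat (B, in_dim) input, whereas the documented forward() accepts (B, F, T)-style tensors.

```python
import torch
from torch import nn


class ResidualMlpSketch(nn.Module):
    """Hypothetical residual MLP; NOT the neuraltrain implementation.

    hidden_dim, n_blocks, dropout and the block layout are assumptions.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 256,
                 n_blocks: int = 2, dropout: float = 0.15) -> None:
        super().__init__()
        # Input projection from in_dim to the hidden width.
        self.lin_in = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        # Width-preserving blocks; forward() wraps each in a skip connection.
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            )
            for _ in range(n_blocks)
        ])
        # Output projection from the hidden width to out_dim.
        self.lin_out = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, in_dim) -> (B, out_dim)
        h = self.lin_in(x)
        for block in self.blocks:
            h = block(h) + h  # residual (skip) connection
        return self.lin_out(h)


model = ResidualMlpSketch(in_dim=1000, out_dim=768).eval()
with torch.no_grad():
    y = model(torch.randn(4, 1000))
print(y.shape)  # torch.Size([4, 768])
```

Each skip connection adds a block's input back to its output. For that to work, every inner block must preserve the hidden width, so only the input and output projections change dimensionality.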

References

[1] Scotti, Paul, et al. “Reconstructing the mind’s eye: fMRI-to-image with contrastive learning and diffusion priors.” Advances in Neural Information Processing Systems 36 (2024).

forward(x: Tensor, subject_ids: Tensor | None = None) → Tensor | dict[str, Tensor]

Forward pass through the residual MLP.

Parameters:
  • x (Tensor) – Input of shape (B, F, T) or (B, ..., T).

  • subject_ids (Tensor or None) – Per-example subject indices, shape (B,).
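The input contract above can be checked before calling forward(). The following is a minimal sketch that uses only the documented shapes. The helper check_forward_inputs is hypothetical and not part of neuraltrain. It also makes two assumptions: that "..." in (B, ..., T) may stand for any number of middle dimensions, and that "indices" means an integer dtype.

```python
from __future__ import annotations

import torch


def check_forward_inputs(x: torch.Tensor, subject_ids: torch.Tensor | None = None) -> None:
    """Hypothetical helper (not part of neuraltrain): sanity-checks forward() inputs."""
    # Documented layout: (B, F, T) or (B, ..., T), i.e. batch first, time last.
    if x.ndim < 2:
        raise ValueError(f"x needs a batch and a time axis, got shape {tuple(x.shape)}")
    if subject_ids is None:
        return
    # One subject index per example: shape (B,).
    if subject_ids.shape != (x.shape[0],):
        raise ValueError(
            f"subject_ids shape {tuple(subject_ids.shape)} does not match batch size {x.shape[0]}"
        )
    # Assumption: indices should be stored as an integer dtype.
    dtype = subject_ids.dtype
    if dtype.is_floating_point or dtype.is_complex or dtype == torch.bool:
        raise TypeError(f"subject_ids should be an integer tensor, got {dtype}")


B, F, T = 8, 512, 6  # illustrative sizes only
x = torch.randn(B, F, T)
subject_ids = torch.randint(0, 4, (B,))  # e.g. 4 subjects, indices 0..3
check_forward_inputs(x, subject_ids)  # valid inputs: returns without raising
```

Because forward() is annotated as returning Tensor | dict[str, Tensor], calling code should branch on isinstance(out, dict) rather than assume a plain tensor.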