neuraltrain.models.preprocessor.OnTheFlyPreprocessorModel

class neuraltrain.models.preprocessor.OnTheFlyPreprocessorModel(config: OnTheFlyPreprocessor)

nn.Module implementation of OnTheFlyPreprocessor.

forward(x: Tensor, channel_positions: Tensor | None = None, **kwargs: Any) → tuple[Tensor, Tensor | None]

Preprocess the input tensor in-place (clone first).

Parameters:
  • x (Tensor) – Input of shape (B, C, T): batch size, number of channels, number of time samples.

  • channel_positions (Tensor or None) – Electrode coordinates of shape (B, C, D), where D is the number of spatial coordinates per electrode. Updated in place when update_ch_pos is True and bad channels are detected.

Returns:

Preprocessed tensor and (possibly updated) channel positions.

Return type:

tuple of (Tensor, Tensor or None)
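The calling contract above (the input is cloned and then processed in place, and channel positions are modified in place only when update_ch_pos is set and bad channels are found) can be illustrated with a hypothetical stand-in module. This is a sketch under stated assumptions, not the neuraltrain implementation: the class name ToyOnTheFlyPreprocessor, the per-channel centering step, the rule that a "bad" channel is one with near-zero variance, and zeroing the positions of bad channels are all invented for illustration, and the sketch reads "(clone first)" as forward cloning x itself. It requires PyTorch.

```python
from __future__ import annotations

from typing import Any

import torch
from torch import Tensor, nn


class ToyOnTheFlyPreprocessor(nn.Module):
    """Hypothetical stand-in mimicking the documented forward() contract.

    Assumptions (NOT taken from neuraltrain): channels are centered,
    "bad" channels are those with near-zero variance over time, and
    bad-channel positions are overwritten with zeros.
    """

    def __init__(self, update_ch_pos: bool = True, eps: float = 1e-8) -> None:
        super().__init__()
        self.update_ch_pos = update_ch_pos  # mirrors the documented update_ch_pos flag
        self.eps = eps  # hypothetical variance threshold for flagging bad channels

    def forward(
        self, x: Tensor, channel_positions: Tensor | None = None, **kwargs: Any
    ) -> tuple[Tensor, Tensor | None]:
        x = x.clone()  # copy so the in-place ops below leave the caller's tensor intact
        x.sub_(x.mean(dim=-1, keepdim=True))  # center each channel in place, (B, C, T)
        bad = x.var(dim=-1) < self.eps  # boolean mask of flat channels, (B, C)
        x[bad] = 0.0  # zero out flagged channels in place
        if self.update_ch_pos and channel_positions is not None and bool(bad.any()):
            channel_positions[bad] = 0.0  # in-place update of the (B, C, D) positions
        return x, channel_positions


# Usage: batch of 2, 4 channels, 100 time samples, 3-D electrode coordinates.
torch.manual_seed(0)
x = torch.randn(2, 4, 100)
x[0, 1] = 3.0  # make channel 1 of the first sample flat, so it is flagged as bad
pos = torch.randn(2, 4, 3)
model = ToyOnTheFlyPreprocessor()
out, new_pos = model(x, pos)
```

In this sketch the returned channel_positions is the same tensor object that was passed in, because it is modified in place; a caller that needs the original coordinates would clone them before the call.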