neuraltrain.models.conv_transformer.ConvTransformer

pydantic model neuraltrain.models.conv_transformer.ConvTransformer[source]

Convolutional encoder followed by optional temporal aggregation and a transformer.

Parameters:
  • dim – Internal token dimension.

  • encoder_config – Configuration for the convolutional encoder.

  • temporal_downsampling_config – Configuration for the optional temporal downsampling module.

  • conv_pos_emb_kernel_size – If provided, use convolutional positional embedding with this kernel size.

  • neuro_device_types – List of expected neuro device types; when provided, the device type can be embedded into the transformer input.

  • add_cls_token – If True, add a [CLS] token to the input of the transformer.

  • pre_transformer_layer_norm – If True, apply layer normalization before the transformer.

  • transformer_config – Configuration for the transformer encoder.

  • output_avg_pool – If True, average-pool the tokens output by the transformer.

  • output_layer_dim – Set to 0 for no output layer, or None to use the same dimension as the transformer. Note that both Bendr and Wav2vec2.0 use an output linear projection, though this is not mentioned in their respective papers.

Fields:
field dim: int = 512[source]
field encoder_config: SimplerConv | SimpleConv [Required][source]
field temporal_downsampling_config: TemporalDownsampling | None = None[source]
field conv_pos_emb_kernel_size: int | None = None[source]
field neuro_device_types: list[str] | None = None[source]
field add_cls_token: bool = False[source]
field pre_transformer_layer_norm: bool = False[source]
field transformer_config: TransformerEncoder | Conformer | None = None[source]
field output_avg_pool: bool = False[source]
field output_layer_dim: int | None = 0[source]
build(n_in_channels: int, n_outputs: int | None = None) → ConvTransformerModel[source]

Build ConvTransformer model.

Parameters:
  • n_in_channels – Number of input channels.

  • n_outputs – Number of output dimensions. If None, use the output_layer_dim parameter from the config.