neuraltrain.models.transformer.TransformerEncoder

pydantic model neuraltrain.models.transformer.TransformerEncoder[source]

Transformer encoder/decoder built on top of x_transformers.

Parameters:
  • heads (int) – Number of attention heads.

  • depth (int) – Number of Transformer layers.

  • cross_attend (bool) – Enable cross-attention (decoder mode).

  • causal (bool) – If True, build a causal Decoder instead of an Encoder.

  • attn_flash (bool) – Use Flash Attention. Not compatible with ALiBi.

  • attn_dropout (float) – Dropout probability inside the attention layers.

  • ff_mult (int) – Feed-forward expansion factor (ff_dim = dim * ff_mult).

  • ff_dropout (float) – Dropout probability in the feed-forward layers.

  • use_scalenorm (bool) – Use ScaleNorm instead of LayerNorm.

  • use_rmsnorm (bool) – Use RMSNorm instead of LayerNorm.

  • rel_pos_bias (bool) – Use relative positional bias.

  • alibi_pos_bias (bool) – Use ALiBi positional bias.

  • rotary_pos_emb (bool) – Use rotary positional embeddings.

  • rotary_xpos (bool) – Use xPos extension for rotary embeddings.

  • residual_attn (bool) – Use residual attention (Realformer-style residual connections between attention matrices across layers).

  • scale_residual (bool) – Scale residual connections.

  • layer_dropout (float) – Probability of dropping an entire Transformer layer during training.

Fields:
field heads: int = 8[source]
field depth: int = 12[source]
field cross_attend: bool = False[source]
field causal: bool = False[source]
field attn_flash: bool = False[source]
field attn_dropout: float = 0.1[source]
field ff_mult: int = 4[source]
field ff_dropout: float = 0.0[source]
field use_scalenorm: bool = True[source]
field use_rmsnorm: bool = False[source]
field rel_pos_bias: bool = False[source]
field alibi_pos_bias: bool = False[source]
field rotary_pos_emb: bool = True[source]
field rotary_xpos: bool = False[source]
field residual_attn: bool = False[source]
field scale_residual: bool = True[source]
field layer_dropout: float = 0.0[source]
build(dim: int) → Module[source]

Construct the underlying x_transformers module with model dimension dim: a causal Decoder when causal=True, otherwise an Encoder.
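The field defaults above can be summarised in a plain-dataclass sketch (stdlib only; the real class is a pydantic model, but the names and defaults below mirror the Fields listing):

```python
from dataclasses import dataclass


@dataclass
class TransformerEncoderConfig:
    """Mirror of TransformerEncoder's fields and defaults, for illustration."""

    heads: int = 8
    depth: int = 12
    cross_attend: bool = False
    causal: bool = False
    attn_flash: bool = False
    attn_dropout: float = 0.1
    ff_mult: int = 4
    ff_dropout: float = 0.0
    use_scalenorm: bool = True
    use_rmsnorm: bool = False
    rel_pos_bias: bool = False
    alibi_pos_bias: bool = False
    rotary_pos_emb: bool = True
    rotary_xpos: bool = False
    residual_attn: bool = False
    scale_residual: bool = True
    layer_dropout: float = 0.0


# Override only what differs from the defaults, as with the pydantic model.
cfg = TransformerEncoderConfig(depth=6, causal=True)

# Feed-forward hidden size follows ff_dim = dim * ff_mult; for dim=512:
ff_dim = 512 * cfg.ff_mult
```

By default the model is a non-causal Encoder with ScaleNorm, rotary position embeddings, and scaled residuals; setting causal=True switches build() to a Decoder.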