Source code for neuraltrain.models.transformer

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Wrappers around Transformer models from x_transformers package.
"""

import logging

import torch.nn as nn

from .base import BaseModelConfig

logger = logging.getLogger(__name__)


[docs] class TransformerEncoder(BaseModelConfig): """Transformer encoder/decoder built on top of ``x_transformers``. Parameters ---------- heads : int Number of attention heads. depth : int Number of Transformer layers. cross_attend : bool Enable cross-attention (decoder mode). causal : bool If True, build a causal ``Decoder`` instead of an ``Encoder``. attn_flash : bool Use Flash Attention. Not compatible with ALiBi. attn_dropout : float Dropout probability inside the attention layers. ff_mult : int Feed-forward expansion factor (``ff_dim = dim * ff_mult``). ff_dropout : float Dropout probability in the feed-forward layers. use_scalenorm : bool Use ScaleNorm instead of LayerNorm. use_rmsnorm : bool Use RMSNorm instead of LayerNorm. rel_pos_bias : bool Use relative positional bias. alibi_pos_bias : bool Use ALiBi positional bias. rotary_pos_emb : bool Use rotary positional embeddings. rotary_xpos : bool Use xPos extension for rotary embeddings. residual_attn : bool Add residual connections around the attention output. scale_residual : bool Scale residual connections. layer_dropout : float Probability of dropping an entire Transformer layer during training. """ heads: int = 8 depth: int = 12 # Attention blocks cross_attend: bool = False causal: bool = False # Use Flash Attention; not compatible with ALiBi and probably other extractors attn_flash: bool = False attn_dropout: float = 0.1 # Feedforward blocks ff_mult: int = 4 # Feedforward expansion factor ff_dropout: float = 0.0 # Normalization use_scalenorm: bool = True use_rmsnorm: bool = False # Positional embedding rel_pos_bias: bool = False alibi_pos_bias: bool = False rotary_pos_emb: bool = True rotary_xpos: bool = False # Others residual_attn: bool = False scale_residual: bool = True layer_dropout: float = 0.0 def build(self, dim: int) -> nn.Module: from x_transformers import Decoder, Encoder # type: ignore if dim % self.heads != 0: raise ValueError( f"dim ({dim}) must be divisible by the number of heads ({self.heads})" ) if self.rotary_pos_emb and dim // self.heads < 32: raise ValueError( f"dim_head ({dim // self.heads}) < 32: x-transformers clamps the rotary " f"embedding dimension to min 32, causing a shape mismatch. " f"Increase dim or reduce heads so that dim // heads >= 32, " f"or disable rotary_pos_emb." ) kwargs = self.model_dump() kwargs["attn_dim_head"] = dim // self.heads del kwargs["name"] del kwargs["causal"] if self.causal: return Decoder(dim=dim, **kwargs) else: return Encoder(dim=dim, **kwargs)