# Source code for xformers.factory.model_factory

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch

from xformers.components import reversible as rv
from xformers.components.residual import ResidualNormStyle, get_deepnorm_coefficients
from xformers.factory.block_configs import (
    xFormerBlockConfig,
    xFormerDecoderConfig,
    xFormerEncoderConfig,
)
from xformers.factory.block_factory import xFormerDecoderBlock, xFormerEncoderBlock
from xformers.factory.weight_init import get_weight_init_fn, xFormerWeightInit

logger = logging.getLogger("xformers")

@dataclass(init=False)
class xFormerConfig:
    """
    The configuration structure to define a full Transformer.
    This can include a stack of encoder layers, and a stack of decoder layers.

    It is optionally possible to share the embedding weights in between
    the encoder and decoder positional encoding, as proposed for instance by
    `Using the Output Embedding to Improve Language Models`_, Press et al.

    A full config example is for instance as follows:

    ::

        xformer_config = [
            {
                "reversible": False,  # Turn on to test the effect of using reversible layers
                "block_type": "encoder",
                "num_layers": LAYERS,
                "dim_model": EMB,
                "residual_norm_style": "pre",
                "position_encoding_config": {
                    "name": "vocab",
                    "seq_len": CONTEXT,
                    "vocab_size": VOCAB_SIZE,
                },
                "multi_head_config": {
                    "num_heads": NUM_HEADS,
                    "residual_dropout": RES_DROP,
                    "use_rotary_embeddings": True,
                    "attention": {
                        "name": ATTENTION_MECHANISM_STR,
                        "dropout": ATTN_DROP,
                        "causal": True,
                        "seq_len": CONTEXT,
                    },
                },
                "feedforward_config": {
                    "name": "FusedMLP",  # Use MLP if Triton is not available
                    "dropout": MLP_DROP,
                    "activation": "gelu",
                    "hidden_layer_multiplier": MLP_MULTIPLIER,
                },
            }
        ]

    .. _`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf
    """

    # Either an ordered list of typed block configs, or a name -> config mapping
    stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]]
    # Share the positional-encoding / embedding module between encoder and decoder
    tie_embedding_weights: bool = False
    # Weight initialization scheme applied to the whole model
    weight_init: xFormerWeightInit = xFormerWeightInit.ViT

    def __init__(
        self,
        stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
        tie_embedding_weights: bool = False,
        weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
    ):
        """
        Build a typed configuration from raw (serialized) stack configurations.

        Args:
            stack_configs: either a list of raw block-config dicts, or a mapping
                from a stack name to a raw block-config dict. The container kind
                (dict vs list) is preserved in ``self.stack_configs``.
            tie_embedding_weights: whether encoder and decoder should share the
                same positional encoding / embedding module.
            weight_init: the weight initialization scheme to use.
        """
        # Type all the configurations.
        # Possible typos are caught here, when the typed config constructors
        # receive keys they do not know about.
        if isinstance(stack_configs, dict):
            self.stack_configs = {
                name: self._type_block_config(config)
                for name, config in stack_configs.items()
            }
        else:
            self.stack_configs = [
                self._type_block_config(config) for config in stack_configs
            ]

        self.tie_embedding_weights = tie_embedding_weights
        self.weight_init = weight_init

    @staticmethod
    def _type_block_config(config: Dict[str, Any]) -> xFormerBlockConfig:
        """Turn one raw block-config dict into its typed encoder/decoder config.

        Only ``block_type == "encoder"`` selects an encoder config; any other
        value (normally ``"decoder"``) builds a decoder config.
        """
        if config["block_type"] == "encoder":
            return xFormerEncoderConfig(**config)
        return xFormerDecoderConfig(**config)
class xFormer(torch.nn.Module):
    """A full Transformer (encoder stack and/or decoder stack) built from block configs."""

    def __init__(
        self,
        stack_configs: Union[
            xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]
        ],
        tie_embedding_weights: bool = False,
        weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
    ):
        """
        Given a serialized configuration, generate the corresponding model.
        This is only a helper and can easily be bypassed

        Args:
            stack_configs: a single block config, a list of them, or a mapping whose
                values are block configs (the mapping order is used as the stack order).
            tie_embedding_weights: share the first decoder's pose encoding with the
                first encoder, when both exist. Not supported with reversible encoders.
            weight_init: the weight initialization scheme to apply.

        Raises:
            AssertionError: on unsupported combinations (mixed reversibility or
                deepnorm settings, reversible + tied embeddings, reversible + deepnorm).
        """
        super().__init__()

        if isinstance(stack_configs, dict):
            stack_configs = list(stack_configs.values())

        # Convenience, users can pass either a list of configs or a single one
        if not isinstance(stack_configs, list):
            stack_configs = [stack_configs]

        # Sanity checks, some config combinations do not make sense
        self._verify_reversible(stack_configs)
        self._verify_deepnorm(stack_configs)

        encoders: List[torch.nn.Module] = []
        decoders: List[torch.nn.Module] = []

        self.reversible_encoder = False
        self.rev_enc_pose_encoding = None

        # Unroll the configs and build the model
        last_stack_idx = len(stack_configs) - 1
        for stack_idx, config in enumerate(stack_configs):
            # Handle either Encoder or Decoder stacks
            is_encoder = isinstance(config, xFormerEncoderConfig)
            builder = (
                xFormerEncoderBlock.from_config
                if is_encoder
                else xFormerDecoderBlock.from_config
            )
            recipient = encoders if is_encoder else decoders

            # Build up the stack
            for i in range(config.num_layers):
                # Label where this layer is in the stack
                # (for instance useful for the positional encoding, or late layer norm).
                # NOTE: this mutates the caller's config object in place.
                if len(recipient) > 0:
                    config.layer_position.mark_not_first()

                # Compare stack *positions*, not config values: dataclass equality
                # would treat two identical stack configs as "the last stack".
                if stack_idx != last_stack_idx or i < config.num_layers - 1:
                    config.layer_position.mark_not_last()

                block = builder(config)  # type: ignore

                # If reversible: extract the reversible sub-parts, else append the block as-is
                if config.reversible:
                    # WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance)
                    assert isinstance(config, xFormerEncoderConfig)
                    if block.pose_encoding is not None:
                        self.rev_enc_pose_encoding = block.pose_encoding
                    self.reversible_encoder = True

                    f, g = xFormerEncoderBlock.get_reversible_layer(config)
                    recipient.append(torch.nn.ModuleList([f, g]))
                else:
                    recipient.append(block)  # type: ignore

        # Tie embedding weights, if requested and possible
        assert (
            not tie_embedding_weights or not self.reversible_encoder
        ), "Reversible layers and tied embeddings is not supported for now"

        if (
            tie_embedding_weights
            and not self.reversible_encoder
            and encoders
            and encoders[0].pose_encoding is not None
            and decoders
            and decoders[0].pose_encoding is not None
        ):
            logger.info("Tying encoder and decoder embeddings, as requested")
            encoders[0].pose_encoding = decoders[0].pose_encoding

        self.encoders: torch.nn.Module = (
            rv.ReversibleSequence(torch.nn.ModuleList(encoders))
            if self.reversible_encoder
            else torch.nn.ModuleList(encoders)
        )
        self.decoders = torch.nn.ModuleList(decoders)

        # An empty stack list has no norm style to inspect: treat it as non-deepnorm
        use_deepnorm = (
            len(stack_configs) > 0
            and stack_configs[0].residual_norm_style == ResidualNormStyle.DeepNorm
        )

        assert (
            not use_deepnorm or not self.reversible_encoder
        ), "Reversible layers and deepnorm is not supported for now"

        self.init_weights(weight_init=weight_init, use_deep_norm=use_deepnorm)

    @classmethod
    def from_config(cls, config: xFormerConfig):
        """Build a model from a typed ``xFormerConfig``."""
        return cls(
            config.stack_configs, config.tie_embedding_weights, config.weight_init
        )

    def _verify_reversible(self, stack_configs: List[xFormerBlockConfig]):
        """Assert that all encoder stacks agree on the ``reversible`` setting."""
        reversible = [
            c.reversible for c in stack_configs if c.block_type == "encoder"
        ]

        assert all(reversible) or not any(reversible), (
            "All layers need to have the same reversibility setting. "
            + f"Currently {reversible}"
        )

    def _verify_deepnorm(self, stack_configs: List[xFormerBlockConfig]):
        """Assert that either every stack uses DeepNorm, or none does."""
        deepnorm = [
            c.residual_norm_style == ResidualNormStyle.DeepNorm for c in stack_configs
        ]

        assert all(deepnorm) or not any(deepnorm), (
            "All layers need to have the same deepnorm setting. "
            + f"Currently {deepnorm}"
        )

    def init_weights(self, weight_init: xFormerWeightInit, use_deep_norm: bool):
        """
        Initialize all encoder and decoder weights with the requested scheme.

        Args:
            weight_init: which init scheme to resolve via ``get_weight_init_fn``.
            use_deep_norm: when True, use the DeepNorm ``beta`` coefficients as the
                encoder/decoder gains, otherwise a gain of 1.0 is used.
        """
        # The deepnorm weight initialization method requires different gain factors for the encoder
        # and decoder, depending on the general model structure (number of respective layers)
        if use_deep_norm:
            encoder_coefficients, decoder_coefficients = get_deepnorm_coefficients(
                encoder_layers=len(self.encoders), decoder_layers=len(self.decoders)  # type: ignore
            )
        else:
            encoder_coefficients, decoder_coefficients = None, None

        encoder_gain = (
            encoder_coefficients.beta if encoder_coefficients is not None else 1.0
        )
        decoder_gain = (
            decoder_coefficients.beta if decoder_coefficients is not None else 1.0
        )

        # Pick the desired init function
        init_fn = get_weight_init_fn(weight_init)

        # Initialize all the encoder weights
        for name, module in self.encoders.named_children():
            init_fn(module=module, name=name, gain=encoder_gain)

        for name, module in self.decoders.named_children():
            init_fn(module=module, name=name, gain=decoder_gain)

    def forward(
        self,
        src: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        encoder_input_mask: Optional[torch.Tensor] = None,
        decoder_input_mask: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """
        Run the encoder stack (if any), then the decoder stack (if any).

        Args:
            src: encoder input; also the decoder input when ``tgt`` is None.
            tgt: optional decoder input.
            encoder_input_mask: optional mask forwarded to the encoders (added to the
                activations in the reversible path).
            decoder_input_mask: optional mask forwarded to the decoders.

        Returns:
            The decoder output when decoders exist, the encoder output for an
            encoder-only model, or None when the model has neither.
        """
        # Stays None when there is no encoder (decoder-only stacks)
        memory: Optional[torch.Tensor] = None

        # Encode to latent space if encoder is present
        if next(self.encoders.parameters(), None) is not None:
            encoders = self.encoders
            memory = src.clone()

            if isinstance(encoders, torch.nn.ModuleList):
                for encoder in encoders:
                    memory = encoder(memory, input_mask=encoder_input_mask)
            else:
                if self.rev_enc_pose_encoding is not None:
                    memory = self.rev_enc_pose_encoding(src)

                # Reversible Encoder: the two streams start as copies of the input,
                # concatenated on the last dimension
                x = torch.cat([memory, memory], dim=-1)

                # Apply the optional input masking
                if encoder_input_mask is not None:
                    # Add a leading (batch) dimension when the mask lacks it;
                    # the original code dropped this unsqueeze's result
                    if x.dim() - encoder_input_mask.dim() > 1:
                        encoder_input_mask = encoder_input_mask.unsqueeze(0)
                    x += encoder_input_mask.unsqueeze(-1)

                x = encoders(x)
                # Average the two reversible streams back into one representation
                memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)

            if not self.decoders:
                return memory

        # If decoder: either use the encoder output, or just decode, both options are possible
        if not self.decoders:
            return None

        tgt = src.clone() if tgt is None else tgt

        for decoder in self.decoders:
            tgt = decoder(
                target=tgt,
                memory=memory,
                input_mask=decoder_input_mask,
            )

        return tgt