Add Your Own Model

What you will learn
  • How to configure a model

  • How to register a model architecture

  • How to use model factories to create models

  • How to use model loaders to load models

Overview

The model configuration and loading system in fairseq2 consists of several key components:

  1. Model Config

    • Defines the architecture and hyperparameters of a model (e.g. number of layers, hidden size, vocabulary info)

  2. Architecture Registry

    • Stores predefined model architectures (e.g. base, large, small, etc.)

  3. Model Factory

    • Creates model instances from configs

  4. Model Loader

    • Handles model instantiation, checkpoint loading and format conversion (e.g. loading from fairseq2 checkpoint, converting from HF checkpoint, etc.)
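Taken together, a typical flow looks like this (a minimal sketch using the hypothetical your_model names introduced in the steps below):

from fairseq2.models.loader import load_model
from fairseq2.models.your_model.factory import create_your_model, your_model_archs

# Architecture registry resolves a named architecture to a model config.
config = your_model_archs.get("base")

# Model factory builds a freshly initialized model from the config.
model = create_your_model(config)

# Model loader combines config resolution, checkpoint loading and format
# conversion, given an asset card name (hypothetical card "your_model_base").
model = load_model("your_model_base")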

Directory Layout

The directory structure for a typical fairseq2 model looks like this:

fairseq2/models/
├── your_model/
│   ├── __init__.py
│   ├── archs.py        # Defines model architectures
│   ├── factory.py      # Contains model factory and config classes
│   ├── loader.py       # Handles model loading and checkpoint conversion
│   └── model.py        # Actual model implementation

Note

The actual layout might vary depending on your implementation.
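For reference, the package __init__.py usually just re-exports the public names from these modules. A sketch, assuming the hypothetical names used throughout this guide:

# fairseq2/models/your_model/__init__.py
from fairseq2.models.your_model.factory import YourModelConfig as YourModelConfig
from fairseq2.models.your_model.factory import create_your_model as create_your_model
from fairseq2.models.your_model.factory import your_model_archs as your_model_archs
from fairseq2.models.your_model.loader import load_your_model as load_your_model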

Step-by-Step Guide

1. Define Model Configuration

First, create a configuration class in factory.py:

from dataclasses import dataclass

from fairseq2.config_registry import ConfigRegistry
from fairseq2.data import VocabularyInfo
from fairseq2.typing import DataType

@dataclass(kw_only=True)
class YourModelConfig:
    """Configuration for YourModel."""
    # Basic model parameters
    model_dim: int = 512
    """The dimensionality of the model."""

    num_layers: int = 6
    """The number of layers in the model."""

    num_heads: int = 8
    """The number of attention heads in the model."""

    ...

In the same file, create a registry for the model config:

your_model_archs = ConfigRegistry[YourModelConfig]()

your_model_arch = your_model_archs.decorator

The resulting your_model_arch decorator is used in the next step to register model architectures, and the your_model_archs registry is what the loader later consults to resolve architecture names.

2. Register Model Architectures

Define the standard architectures in archs.py, using the your_model_arch decorator created in step 1:

from fairseq2.models.your_model.factory import YourModelConfig, your_model_arch

@your_model_arch("base")
def _base() -> YourModelConfig:
    """Base architecture."""
    return YourModelConfig()

@your_model_arch("large")
def _large() -> YourModelConfig:
    """Large architecture."""
    config = YourModelConfig()
    config.model_dim = 1024
    config.num_layers = 12
    config.num_heads = 16
    return config
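Once registered, an architecture config can be retrieved from the registry by name; this is also how the loader resolves architecture names at load time. A quick sketch:

from fairseq2.models.your_model.factory import your_model_archs

config = your_model_archs.get("large")

assert config.model_dim == 1024
assert config.num_layers == 12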

Note

Keep the architecture names descriptive and simple. Document differences between architectures.

Some real-world examples
  • Base Transformer Architecture

The base Transformer model provides a foundation that other models can build upon:

# In transformer/archs.py
from fairseq2.models.transformer.factory import TransformerConfig, transformer_arch

@transformer_arch("base")
def _base() -> TransformerConfig:
    """Base architecture with default parameters."""
    return TransformerConfig()

@transformer_arch("big")
def _big() -> TransformerConfig:
    """Larger architecture with modified parameters."""
    config = TransformerConfig()
    config.model_dim = 1024
    config.num_encoder_attn_heads = 16
    config.num_decoder_attn_heads = 16
    config.ffn_inner_dim = 4096
    config.dropout_p = 0.3
    return config
  • NLLB (No Language Left Behind)

NLLB extends the base Transformer architecture with specific configurations for multilingual translation:

# In nllb/archs.py
from fairseq2.data import VocabularyInfo
from fairseq2.models.transformer.factory import TransformerConfig, transformer_arch, transformer_archs
from fairseq2.nn.transformer import TransformerNormOrder
@transformer_arch("nllb_dense_600m")
def _dense_600m() -> TransformerConfig:
    config = _dense_1b()  # Inherits from larger architecture

    # Modify for smaller model
    config.num_encoder_layers = 12
    config.num_decoder_layers = 12
    config.ffn_inner_dim = 1024 * 4

    return config

@transformer_arch("nllb_dense_1b")
def _dense_1b() -> TransformerConfig:
    config = transformer_archs.get("base")  # Start from base transformer

    # Customize for NLLB
    config.model_dim = 1024
    config.vocab_info = VocabularyInfo(
        size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
    )
    config.num_encoder_layers = 24
    config.num_decoder_layers = 24
    config.num_encoder_attn_heads = 16
    config.num_decoder_attn_heads = 16
    config.ffn_inner_dim = 1024 * 8
    config.norm_order = TransformerNormOrder.PRE

    return config
  • LLaMA Architecture

LLaMA introduces its own configuration class with specific parameters for large language models:

# In llama/archs.py
from fairseq2.models.llama.factory import LLaMAConfig, llama_arch
@llama_arch("7b")
def _7b() -> LLaMAConfig:
    """7B parameter model."""
    return LLaMAConfig()  # Uses default parameters

@llama_arch("13b")
def _13b() -> LLaMAConfig:
    """13B parameter model."""
    config = _7b()
    config.model_dim = 5120
    config.num_attn_heads = 40
    config.num_key_value_heads = 40
    config.ffn_inner_dim = 5120 * 4
    return config

@llama_arch("llama2_70b")
def _llama2_70b() -> LLaMAConfig:
    """LLaMA 2 70B parameter model."""
    config = _65b()
    config.max_seq_len = 4096
    config.num_key_value_heads = 8
    config.ffn_inner_dim = int(8192 * 4 * 1.3)  # See A.2.1 in LLaMA 2
    config.ffn_inner_dim_to_multiple = 4096
    return config

3. Create Model Factory

Implement a factory function in factory.py that creates model instances:

def create_your_model(config: YourModelConfig) -> YourModel:
    """Create a model instance from config."""
    model = YourModel(
        model_dim=config.model_dim,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        dropout_p=config.dropout_p,
        vocab_info=config.vocab_info,
    )

    # Convert to specified dtype
    model.to(dtype=config.dtype)

    return model
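As in the LLaMA example below, register the factory with fairseq2's global model_factories registry so the model can be created generically by family name. A sketch, assuming a hypothetical YOUR_MODEL_FAMILY constant (the exact import path may differ between fairseq2 versions):

from fairseq2.models.factory import model_factories

YOUR_MODEL_FAMILY = "your_model"

model_factories.register(
    YOUR_MODEL_FAMILY, create_your_model, YourModelConfig, your_model_archs
)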
Some real-world examples
  • LLaMA Model Factory

We will use the fairseq2.models.llama.factory.create_llama_model function as an example.

The create_llama_model function serves as a factory method for instantiating a LLaMA model. It encapsulates the process of building a model with the LLaMABuilder class, which constructs various components of the model based on the provided configuration. This design pattern allows for a clean separation of model creation logic, making it easier for users to customize and extend the model architecture.

# In llama/factory.py
class LLaMABuilder:
    ...

    def build_model(self) -> TransformerDecoderModel:
        """Build a model."""
        decoder_frontend = self.build_decoder_frontend()

        decoder = self.build_decoder()

        final_proj = Linear(...)

        model = TransformerDecoderModel(
            decoder_frontend,
            decoder,
            final_proj,
            ...
        )

        model.set_family(LLAMA_FAMILY)

        return model


def create_llama_model(
    config: LLaMAConfig,
    *,
    device: Device | None = None,
    dtype: DataType | None = None,
) -> TransformerDecoderModel:
    """Create a LLaMA model."""
    return LLaMABuilder(config, device=device, dtype=dtype).build_model()


model_factories.register(LLAMA_FAMILY, create_llama_model, LLaMAConfig, llama_archs)

create_llama_model instantiates the builder class and calls its build_model method, which actually creates the model as a TransformerDecoderModel. Don't forget to register your model with the fairseq2 model factories so that it can be easily instantiated later.

4. Set Up Model Loader

Create a loader in loader.py that handles model instantiation and checkpoint loading:

from typing import Any

from fairseq2.models.config_loader import StandardModelConfigLoader
from fairseq2.models.loader import StandardModelLoader, load_model

# Create config loader
load_your_model_config = StandardModelConfigLoader(
    YOUR_MODEL_FAMILY,
    YourModelConfig,
    your_model_archs
)

def convert_your_model_checkpoint(
    checkpoint: dict[str, Any], config: YourModelConfig
) -> dict[str, Any]:
    """Convert external checkpoints to fairseq2 format."""
    # Add checkpoint conversion logic here
    return {"model": checkpoint}

# Create model loader
load_your_model = StandardModelLoader(
    config_loader=load_your_model_config,
    factory=create_your_model,
    checkpoint_converter=convert_your_model_checkpoint,
)

# Register loader with global registry
load_model.register(YOUR_MODEL_FAMILY, load_your_model)
Some real-world examples of checkpoint conversion

The convert_your_model_checkpoint function is a checkpoint converter that converts external checkpoints to fairseq2 format. For example, in Mistral, the checkpoint format is different from fairseq2’s.

# In mistral/loader.py
from typing import Any

from fairseq2.models.utils.checkpoint import convert_model_state_dict
def convert_mistral_checkpoint(
    checkpoint: dict[str, Any], config: MistralConfig
) -> dict[str, Any]:
    """Convert Mistral checkpoint to fairseq2 format."""
    if "model" in checkpoint:  # Already in fairseq2 format
        return checkpoint

    # Map parameter names from Mistral to fairseq2 format
    key_map = {
        r"^layers\.([0-9]+)\.attention\.wq\.":    r"decoder.layers.\1.self_attn.q_proj.",
        r"^layers\.([0-9]+)\.attention\.wk\.":    r"decoder.layers.\1.self_attn.k_proj.",
        r"^layers\.([0-9]+)\.attention\.wv\.":    r"decoder.layers.\1.self_attn.v_proj.",
        # ... more mappings
    }

    checkpoint = convert_model_state_dict(checkpoint, key_map)
    return {"model": checkpoint}

Overall, to support loading from different checkpoint formats:

  1. Modify the checkpoint converter function

  2. Add mapping logic for different parameter names

  3. Handle any necessary tensor transformations, as shown in the sketch below
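For example, steps 2 and 3 might look like the following. This is a hypothetical sketch: the key names, the dropped buffer and the transposed-weight fix-up are illustrative, not taken from a real model:

from typing import Any

from fairseq2.models.utils.checkpoint import convert_model_state_dict

def convert_other_checkpoint(
    checkpoint: dict[str, Any], config: YourModelConfig
) -> dict[str, Any]:
    """Convert a hypothetical external checkpoint to fairseq2 format."""
    # Drop buffers that have no fairseq2 counterpart (illustrative key name).
    checkpoint.pop("rope.freqs", None)

    # Map parameter names to fairseq2's naming scheme.
    key_map = {
        r"^layers\.([0-9]+)\.feed_forward\.w1\.": r"decoder.layers.\1.ffn.gate_proj.",
        # ... more mappings
    }
    checkpoint = convert_model_state_dict(checkpoint, key_map)

    # Fix up tensors whose layout differs (illustrative: transposed projections).
    for key, value in list(checkpoint.items()):
        if key.endswith("output_proj.weight"):
            checkpoint[key] = value.transpose(0, 1)

    return {"model": checkpoint}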

Advanced topic: Sharding

The sharder argument of StandardModelLoader is a function that shards the model across devices, which is useful for distributed training. fairseq2 supports this natively, so you don't need to implement the sharding logic yourself. For example, in LLaMA, the shard_llama_model function shards the model across the tensor-parallel gang:

# In llama/loader.py
from collections.abc import Mapping

from fairseq2.gang import Gang
from fairseq2.models.llama.factory import LLaMAConfig
from fairseq2.models.loader import StandardModelLoader
from fairseq2.models.transformer import TransformerDecoderModel, shard_transformer_decoder_model

def shard_llama_model(
    model: TransformerDecoderModel, config: LLaMAConfig, gangs: Mapping[str, Gang]
) -> None:
    gang = gangs["tp"]  # tensor parallel

    shard_embed_dim = config.max_seq_len < 8192  # LLaMA 1 or 2

    shard_transformer_decoder_model(model, gang, shard_embed_dim=shard_embed_dim)


load_llama_model = StandardModelLoader(
    ...
    sharder=shard_llama_model,
)

5. Using with Trainer

The model can be used with the fairseq2 trainer:

from torch import Tensor

from fairseq2.metrics import MetricBag
from fairseq2.models.loader import load_model
from fairseq2.models.sequence import SequenceBatch
from fairseq2.recipes.trainer import AbstractTrainUnit, Trainer
from fairseq2.recipes.utils.asset import retrieve_asset_card
from fairseq2.typing import Device

model_card = retrieve_asset_card("llama3_2_1b")

# Load model
model = load_model(
    model_card,
    device=Device("cpu"),
)

# Create training unit
class YourTrainUnit(AbstractTrainUnit[SequenceBatch]):
    def __init__(self, model: YourModel) -> None:
        super().__init__(model)
        self._metric_bag = MetricBag()

    def __call__(self, batch: SequenceBatch) -> tuple[Tensor, int]:
        loss = self._model(batch)
        return loss, batch.num_targets

# Set up trainer
trainer = Trainer(
    unit=YourTrainUnit(model),
    data_reader=your_data_reader,
    optimizer=your_optimizer,
    # ... other trainer parameters
)

# Run training
trainer()

For a real-world example, see the fairseq2.recipes.lm recipe.

Best Practices

  1. Configuration:

    • Provide sensible defaults for all parameters

    • Document each config parameter

  2. Architecture Registry:

    • Use descriptive names for architectures

    • Keep base architectures simple

    • Document differences between architectures

  3. Model Loading:

    • Handle checkpoint format differences gracefully

    • Validate config parameters before model creation (see the validation sketch after this list)

    • Provide clear error messages for invalid configs

  4. Training Integration:

    • Create a dedicated training unit for your model

    • Implement proper metric tracking

    • Handle device placement and dtype conversion
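As referenced above, a small validation helper can catch bad or inconsistent parameters before any tensors are allocated. A minimal sketch, assuming the hypothetical YourModelConfig from step 1:

def validate_your_model_config(config: YourModelConfig) -> None:
    """Raise a descriptive error for invalid or inconsistent parameters."""
    if config.num_layers <= 0:
        raise ValueError(f"`num_layers` must be positive, got {config.num_layers}.")

    # Multi-head attention requires the model dimension to split evenly.
    if config.model_dim % config.num_heads != 0:
        raise ValueError(
            f"`model_dim` ({config.model_dim}) must be divisible by "
            f"`num_heads` ({config.num_heads})."
        )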

Common Pitfalls

  1. Checkpoint Compatibility:

    • Ensure checkpoint conversion handles all parameter mappings

    • Verify tensor shapes and dtypes match

    • Handle missing or extra parameters gracefully (see the verification sketch at the end of this section)

  2. Configuration Issues:

    • Validate all config parameters before use

    • Handle interdependent parameters correctly

    • Document any parameter constraints

  3. Training Problems:

    • Ensure proper device placement

    • Handle batch processing efficiently

    • Implement correct loss computation
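To guard against the checkpoint-compatibility pitfalls above, a converted checkpoint can be checked against the model's state dict before it is loaded. A minimal sketch using plain PyTorch:

from typing import Any

import torch

def check_converted_checkpoint(
    model: torch.nn.Module, checkpoint: dict[str, Any]
) -> None:
    """Compare a converted checkpoint's keys against the model's state dict."""
    state_dict = checkpoint["model"]

    model_keys = set(model.state_dict())
    ckpt_keys = set(state_dict)

    if missing := model_keys - ckpt_keys:
        raise ValueError(f"Checkpoint is missing parameters: {sorted(missing)}")

    if extra := ckpt_keys - model_keys:
        raise ValueError(f"Checkpoint has unexpected parameters: {sorted(extra)}")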