Add Your Own Model

What you will learn
  • How to configure a model

  • How to implement a model class

  • How to create a model factory and handler

  • How to register your model with fairseq2

Prerequisites

Overview

The model system in fairseq2 consists of several key components:

  1. Model Config

    • Defines the architecture and hyperparameters

    • Supports different model variants through config presets

  2. Model Class

    • Implements the actual model architecture

    • Inherits from appropriate base classes

  3. Model Factory

    • Creates model instances from configs

    • Most importantly, provides a create_model method

  4. Model Handler

    • Manages model creation and checkpoint loading

    • Converts between different checkpoint formats

Directory Layout

The directory structure for a fairseq2 model typically looks like this:

src/fairseq2/models/your_model/
├── __init__.py
├── config.py      # Model configuration and presets
├── factory.py     # Model factory
├── handler.py     # Model handler for creation and loading
└── model.py       # Model implementation

Step-by-Step Guide

1. Define Model Configuration

First, create a configuration class in config.py:

from dataclasses import dataclass

from fairseq2.context import RuntimeContext
from fairseq2.data import VocabularyInfo

@dataclass(kw_only=True)
class YourModelConfig:
    """Holds the configuration of your model."""

    model_dim: int = 512
    """The dimensionality of the model."""

    max_seq_len: int = 2048
    """The maximum sequence length."""

    vocab_info: VocabularyInfo
    """The vocabulary information."""

def register_your_model_configs(context: RuntimeContext) -> None:
    """Register model architecture presets."""
    registry = context.get_config_registry(YourModelConfig)

    arch = registry.decorator

    @arch("base")
    def your_model_base() -> YourModelConfig:
        return YourModelConfig(
            vocab_info=VocabularyInfo(
                size=32000,
                unk_idx=0,
                bos_idx=1,
                eos_idx=2,
                pad_idx=None
            )
        )

2. Create Model Class

Implement your model in model.py:

from typing import final

from torch import Tensor
from typing_extensions import override

from fairseq2.data import VocabularyInfo
from fairseq2.models.decoder import DecoderModel
from fairseq2.nn import IncrementalStateBag
from fairseq2.nn.padding import PaddingMask

@final
class YourModel(DecoderModel):
    """Your model implementation."""

    def __init__(
        self,
        model_dim: int,
        max_seq_len: int,
        vocab_info: VocabularyInfo,
    ) -> None:
        super().__init__(model_dim, max_seq_len, vocab_info)

        # Initialize your model components here

    @override
    def decode(
        self,
        seqs: Tensor,
        padding_mask: PaddingMask | None,
        *,
        state_bag: IncrementalStateBag | None = None,
    ) -> tuple[Tensor, PaddingMask]:
        # Implement your decoding logic here. Raising keeps the placeholder
        # from silently returning None.
        raise NotImplementedError
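A base class may declare more than one abstract method, and Python refuses to instantiate a subclass until all of them are implemented. The sketch below shows this with the standard `abc` module only. `DemoDecoderBase` and its method names are hypothetical placeholders and are not taken from fairseq2; it assumes the base class marks its required methods with `@abstractmethod`.

```python
from abc import ABC, abstractmethod


class DemoDecoderBase(ABC):
    """Hypothetical base class with two abstract methods."""

    @abstractmethod
    def decode(self, seqs: list[int]) -> list[int]: ...

    @abstractmethod
    def project_logits(self, decoder_output: list[int]) -> list[float]: ...


class IncompleteModel(DemoDecoderBase):
    # Only one of the two abstract methods is implemented.
    def decode(self, seqs: list[int]) -> list[int]:
        return seqs


class CompleteModel(DemoDecoderBase):
    def decode(self, seqs: list[int]) -> list[int]:
        return seqs

    def project_logits(self, decoder_output: list[int]) -> list[float]:
        return [float(x) for x in decoder_output]


def can_instantiate(cls: type) -> bool:
    """Return True if ``cls`` can be constructed without arguments."""
    try:
        cls()
    except TypeError:
        # Raised when abstract methods are still unimplemented.
        return False
    return True
```

Here `can_instantiate(CompleteModel)` succeeds while `can_instantiate(IncompleteModel)` does not, which is a quick way to check in a unit test that no required method was overlooked.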

3. Implement Model Factory

Note

This factory pattern is a convention and is not strictly required. It is useful because a subclass can override specific parts of the model construction logic. The essential piece for integrating with fairseq2 is a create_model method that builds a YourModel from a YourModelConfig.

Create a factory in factory.py:

from fairseq2.models.your_model.config import YourModelConfig
from fairseq2.models.your_model.model import YourModel

class YourModelFactory:
    """Creates model instances."""

    _config: YourModelConfig

    def __init__(self, config: YourModelConfig) -> None:
        self._config = config

    def create_model(self) -> YourModel:
        """Creates a model instance."""
        config = self._config

        return YourModel(
            model_dim=config.model_dim,
            max_seq_len=config.max_seq_len,
            vocab_info=config.vocab_info,
        )

4. Create Model Handler

Implement a handler in handler.py:

from typing import cast

from torch.nn import Module
from typing_extensions import override

from fairseq2.models import AbstractModelHandler
from fairseq2.models.your_model.config import YourModelConfig
from fairseq2.models.your_model.factory import YourModelFactory
from fairseq2.models.your_model.model import YourModel

class YourModelHandler(AbstractModelHandler):
    # A 'family' represents a group of related models sharing a common
    # architecture. For instance, `llama` is the model family of
    # `llama_3_2_8b_instruct`.
    @override
    @property
    def family(self) -> str:
        return "my_model_family"

    @override
    @property
    def kls(self) -> type[Module]:
        return YourModel

    @override
    def _create_model(self, config: object) -> Module:
        config = cast(YourModelConfig, config)

        return YourModelFactory(config).create_model()

5. Register the Model

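Register the handler and its config presets in your setup_fairseq2_extension function:

from fairseq2.context import RuntimeContext
from fairseq2.models import ModelHandler
from fairseq2.models.your_model.config import (
    YourModelConfig,
    register_your_model_configs,
)
from fairseq2.models.your_model.handler import YourModelHandler

def setup_fairseq2_extension(context: RuntimeContext) -> None:
    # fairseq2's global model registry.
    model_registry = context.get_registry(ModelHandler)

    # Register my model.
    configs = context.get_config_registry(YourModelConfig)

    default_arch = "base"

    # NOTE: `asset_download_manager` and `tensor_loader` are not defined in
    # this snippet; they must be created or obtained before this point.
    handler = YourModelHandler(
        configs, default_arch, asset_download_manager, tensor_loader
    )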

    model_registry.register(handler.family, handler)

    # Register my model architecture configurations.
    register_your_model_configs(context)

Best Practices

  1. Configuration:
    • Make all parameters type-safe and well-documented

    • Use sensible defaults

    • Register different architectures as config presets

  2. Model Implementation:
    • Inherit from appropriate base classes

    • Use type hints and proper documentation

    • Implement all required abstract methods

  3. Checkpoint Loading:
    • Handle different checkpoint formats gracefully

    • Use convert_model_state_dict for key mapping

    • Validate checkpoint contents

  4. Testing:
    • Add unit tests for model components

    • Test checkpoint loading

    • Verify model outputs
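The key-mapping advice under Checkpoint Loading can be illustrated without any framework. The sketch below renames checkpoint keys with regular expressions. `remap_state_dict_keys` is a hypothetical helper written for this example; it is not fairseq2's convert_model_state_dict and makes no claim about that function's signature. The key names and string values stand in for real parameter names and tensors.

```python
import re


def remap_state_dict_keys(
    state_dict: dict[str, object], key_map: dict[str, str]
) -> dict[str, object]:
    """Rename each key with the first matching pattern; keep other keys as-is."""
    converted: dict[str, object] = {}

    for old_key, value in state_dict.items():
        new_key = old_key

        for pattern, replacement in key_map.items():
            new_key, count = re.subn(pattern, replacement, old_key)
            if count > 0:
                break  # Stop at the first pattern that matched.

        if new_key in converted:
            raise ValueError(f"Two checkpoint keys map to '{new_key}'.")

        converted[new_key] = value

    return converted


# Hypothetical mapping from a foreign checkpoint layout to the model's layout.
key_map = {
    r"^layers\.([0-9]+)\.attn\.": r"decoder.layers.\1.self_attn.",
    r"^embed\.": r"decoder_frontend.embed.",
}

checkpoint = {
    "layers.0.attn.q_proj.weight": "W_q",
    "embed.weight": "E",
    "final_norm.weight": "N",
}

converted = remap_state_dict_keys(checkpoint, key_map)
```

Raising on key collisions is a deliberate choice: two source keys silently overwriting each other is much harder to diagnose later than a failure at load time.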

Common Pitfalls

  1. Type Safety:
    • Always use type hints

    • Validate config parameters

  2. Checkpoint Compatibility:
    • Handle missing or extra parameters

    • Verify tensor shapes and dtypes

    • Document supported checkpoint formats

  3. Model Registration:
    • Register configs before using them

    • Set an appropriate default architecture

    • Handle dependencies correctly