The directory structure for a fairseq2 model typically looks like this:
```
src/fairseq2/models/your_model/
├── __init__.py
├── _config.py     # Model configuration and presets
├── _factory.py    # Model factory
├── _handler.py    # Model handler for creation and loading
└── _model.py      # Model implementation
```
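Each of these modules is covered below. The package's __init__.py itself typically just re-exports their public names; here is a minimal sketch, assuming the layout above (adapt the list to whatever you actually expose):

```python
# src/fairseq2/models/your_model/__init__.py
#
# Hypothetical sketch: re-export the public API of the package so that users
# can import directly from fairseq2.models.your_model.
from fairseq2.models.your_model._config import (
    YourModelConfig,
    register_your_model_configs,
)
from fairseq2.models.your_model._factory import YourModelFactory
from fairseq2.models.your_model._handler import YourModelHandler
from fairseq2.models.your_model._model import YourModel

__all__ = [
    "YourModel",
    "YourModelConfig",
    "YourModelFactory",
    "YourModelHandler",
    "register_your_model_configs",
]
```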
Define the model configuration and its architecture presets in _config.py:

```python
from dataclasses import dataclass

from fairseq2.context import RuntimeContext
from fairseq2.data import VocabularyInfo


@dataclass(kw_only=True)
class YourModelConfig:
    """Holds the configuration of your model."""

    model_dim: int = 512
    """The dimensionality of the model."""

    max_seq_len: int = 2048
    """The maximum sequence length."""

    vocab_info: VocabularyInfo
    """The vocabulary information."""


def register_your_model_configs(context: RuntimeContext) -> None:
    """Register model architecture presets."""
    registry = context.get_config_registry(YourModelConfig)

    arch = registry.decorator

    @arch("base")
    def your_model_base() -> YourModelConfig:
        return YourModelConfig(
            vocab_info=VocabularyInfo(
                size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
            )
        )
```
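Because YourModelConfig is a plain dataclass, configs are easy to construct and tweak by hand; a small sketch using only the standard library (the override below is purely illustrative):

```python
from dataclasses import replace

from fairseq2.data import VocabularyInfo

# Build a config by hand, mirroring the "base" preset above.
config = YourModelConfig(
    vocab_info=VocabularyInfo(
        size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
    )
)

# Derive a variant without mutating the original (illustrative value).
wide_config = replace(config, model_dim=1024)

assert config.model_dim == 512
assert wide_config.model_dim == 1024
```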
Implement the model itself in _model.py:

```python
from typing import final

from torch import Tensor
from typing_extensions import override

from fairseq2.data import VocabularyInfo
from fairseq2.models.decoder import DecoderModel
from fairseq2.nn import IncrementalStateBag
from fairseq2.nn.padding import PaddingMask


@final
class YourModel(DecoderModel):
    """Your model implementation."""

    def __init__(
        self,
        model_dim: int,
        max_seq_len: int,
        vocab_info: VocabularyInfo,
    ) -> None:
        super().__init__(model_dim, max_seq_len, vocab_info)

        # Initialize your model components here.

    @override
    def decode(
        self,
        seqs: Tensor,
        padding_mask: PaddingMask | None,
        *,
        state_bag: IncrementalStateBag | None = None,
    ) -> tuple[Tensor, PaddingMask]:
        # Implement your decoding logic.
        pass
```
The factory pattern shown below is a convention, not a strict requirement. It makes it easy to subclass the factory and override specific parts of the model construction logic when needed (see the sketch after the factory code). What matters most is exposing a create_model(config: YourModelConfig) -> YourModel entry point so the model integrates with fairseq2.
Create a factory in _factory.py:
```python
from fairseq2.models.your_model._config import YourModelConfig
from fairseq2.models.your_model._model import YourModel


class YourModelFactory:
    """Creates model instances."""

    _config: YourModelConfig

    def __init__(self, config: YourModelConfig) -> None:
        self._config = config

    def create_model(self) -> YourModel:
        """Creates a model instance."""
        config = self._config

        return YourModel(
            model_dim=config.model_dim,
            max_seq_len=config.max_seq_len,
            vocab_info=config.vocab_info,
        )
```
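For example, a subclass can hook into the construction logic without re-implementing it; a hypothetical sketch (the divisibility check is purely illustrative):

```python
class YourCheckedModelFactory(YourModelFactory):
    """Adds a config sanity check, then reuses the base construction logic."""

    def create_model(self) -> YourModel:
        # Hypothetical validation step; adjust to your own constraints.
        if self._config.model_dim % 8 != 0:
            raise ValueError("model_dim must be a multiple of 8.")

        return super().create_model()
```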
Next, implement the model handler in _handler.py. The handler tells fairseq2 how to create and load your model:

```python
from typing import cast

from torch.nn import Module
from typing_extensions import override

from fairseq2.models import AbstractModelHandler
from fairseq2.models.your_model._config import YourModelConfig
from fairseq2.models.your_model._factory import YourModelFactory
from fairseq2.models.your_model._model import YourModel


class YourModelHandler(AbstractModelHandler):
    # A 'family' represents a group of related models sharing a common
    # architecture. For instance, `llama` is the model family of
    # `llama_3_2_8b_instruct`.
    @override
    @property
    def family(self) -> str:
        return "my_model_family"

    @override
    @property
    def kls(self) -> type[Module]:
        return YourModel

    @override
    def _create_model(self, config: object) -> Module:
        config = cast(YourModelConfig, config)

        return YourModelFactory(config).create_model()
```
Finally, register the handler and the architecture presets with fairseq2 in your extension's setup function:

```python
from fairseq2.context import RuntimeContext
from fairseq2.models import ModelHandler
from fairseq2.models.your_model._config import YourModelConfig, register_your_model_configs
from fairseq2.models.your_model._handler import YourModelHandler


def setup_fairseq2_extension(context: RuntimeContext) -> None:
    # fairseq2's global model registry.
    model_registry = context.get_registry(ModelHandler)

    # Register my model. `asset_download_manager` and `tensor_loader` are
    # assumed to be created elsewhere in the extension module.
    configs = context.get_config_registry(YourModelConfig)

    default_arch = "base"

    handler = YourModelHandler(
        configs, default_arch, asset_download_manager, tensor_loader
    )

    model_registry.register(handler.family, handler)

    # Register my model architecture configurations.
    register_your_model_configs(context)
```
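Outside of the registry machinery, the pieces also compose directly, which is handy for a quick smoke test; a minimal sketch using only the classes defined above:

```python
from fairseq2.data import VocabularyInfo
from fairseq2.models.your_model._config import YourModelConfig
from fairseq2.models.your_model._factory import YourModelFactory

# Build a config by hand and construct the model through the factory.
config = YourModelConfig(
    vocab_info=VocabularyInfo(
        size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
    )
)

model = YourModelFactory(config).create_model()

print(model)
```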