The directory structure for a typical fairseq2 model looks like this:
fairseq2/models/
├── your_model/
│   ├── __init__.py
│   ├── archs.py      # Defines model architectures
│   ├── factory.py    # Contains model factory and config classes
│   ├── loader.py     # Handles model loading and checkpoint conversion
│   └── model.py      # Actual model implementation
Note: The actual layout might vary depending on your implementation.
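For reference, the package's __init__.py often does nothing more than re-export the public pieces so that importing fairseq2.models.your_model is enough to make the architectures and loader available. The exact exports below are an assumption that mirrors the naming used in the rest of this page, not something fairseq2 mandates:

# In your_model/__init__.py -- a sketch; the names mirror the examples below
# and are not required by fairseq2.
from fairseq2.models.your_model.factory import YOUR_MODEL_FAMILY as YOUR_MODEL_FAMILY
from fairseq2.models.your_model.factory import YourModelConfig as YourModelConfig
from fairseq2.models.your_model.factory import create_your_model as create_your_model
from fairseq2.models.your_model.factory import your_model_archs as your_model_archs
from fairseq2.models.your_model.loader import load_your_model as load_your_model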
First, create a configuration class in factory.py:
from dataclasses import dataclass

from fairseq2.typing import DataType
from fairseq2.data import VocabularyInfo


@dataclass(kw_only=True)
class YourModelConfig:
    """Configuration for YourModel."""

    # Basic model parameters
    model_dim: int = 512
    """The dimensionality of the model."""

    num_layers: int = 6
    """The number of layers in the model."""

    num_heads: int = 8
    """The number of attention heads in the model."""

    ...
In the same file, create a registry for the model config:
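A minimal sketch of what that registry can look like, assuming fairseq2's ConfigRegistry helper; the exact class and import path depend on your fairseq2 version:

# In factory.py -- a minimal sketch, assuming fairseq2's ConfigRegistry helper;
# the exact class/import path depends on the fairseq2 version.
from fairseq2.config_registry import ConfigRegistry

YOUR_MODEL_FAMILY = "your_model"

your_model_archs = ConfigRegistry[YourModelConfig]()

your_model_arch = your_model_archs.decorator


@your_model_arch("base")
def _base() -> YourModelConfig:
    """The default architecture."""
    return YourModelConfig()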
Keep the architecture names descriptive and simple. Document differences between architectures.
Some real-world examples
Base Transformer Architecture
The base Transformer model provides a foundation that other models can build upon:
# In transformer/archs.py
from fairseq2.models.transformer.factory import TransformerConfig, transformer_arch


@transformer_arch("base")
def _base() -> TransformerConfig:
    """Base architecture with default parameters."""
    return TransformerConfig()


@transformer_arch("big")
def _big() -> TransformerConfig:
    """Larger architecture with modified parameters."""
    config = TransformerConfig()

    config.model_dim = 1024
    config.num_encoder_attn_heads = 16
    config.num_decoder_attn_heads = 16
    config.ffn_inner_dim = 4096
    config.dropout_p = 0.3

    return config
NLLB (No Language Left Behind)
NLLB extends the base Transformer architecture with specific configurations for multilingual translation:
# In nllb/archs.py
@transformer_arch("nllb_dense_600m")
def _dense_600m() -> TransformerConfig:
    config = _dense_1b()  # Inherits from the larger architecture

    # Modify for the smaller model
    config.num_encoder_layers = 12
    config.num_decoder_layers = 12
    config.ffn_inner_dim = 1024 * 4

    return config


@transformer_arch("nllb_dense_1b")
def _dense_1b() -> TransformerConfig:
    config = transformer_archs.get("base")  # Start from the base Transformer

    # Customize for NLLB
    config.model_dim = 1024
    config.vocab_info = VocabularyInfo(
        size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
    )
    config.num_encoder_layers = 24
    config.num_decoder_layers = 24
    config.num_encoder_attn_heads = 16
    config.num_decoder_attn_heads = 16
    config.ffn_inner_dim = 1024 * 8
    config.norm_order = TransformerNormOrder.PRE

    return config
LLaMA Architecture
LLaMA introduces its own configuration class with specific parameters for large language models:
# In llama/archs.py
@llama_arch("7b")
def _7b() -> LLaMAConfig:
    """7B parameter model."""
    return LLaMAConfig()  # Uses default parameters


@llama_arch("13b")
def _13b() -> LLaMAConfig:
    """13B parameter model."""
    config = _7b()

    config.model_dim = 5120
    config.num_attn_heads = 40
    config.num_key_value_heads = 40
    config.ffn_inner_dim = 5120 * 4

    return config


@llama_arch("llama2_70b")
def _llama2_70b() -> LLaMAConfig:
    """LLaMA 2 70B parameter model."""
    config = _65b()

    config.max_seq_len = 4096
    config.num_key_value_heads = 8
    config.ffn_inner_dim = int(8192 * 4 * 1.3)  # See A.2.1 in the LLaMA 2 paper
    config.ffn_inner_dim_to_multiple = 4096

    return config
Implement a factory function in factory.py that creates model instances:
def create_your_model(config: YourModelConfig) -> YourModel:
    """Create a model instance from config."""
    model = YourModel(
        model_dim=config.model_dim,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        dropout_p=config.dropout_p,
        vocab_info=config.vocab_info,
    )

    # Convert to the specified dtype
    model.to(dtype=config.dtype)

    return model
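As a quick usage sketch (assuming the your_model_archs registry created earlier), a registered architecture config can be fetched by name and passed straight to the factory:

# A usage sketch, assuming the registry and factory defined above.
config = your_model_archs.get("base")

# Optionally tweak the config before building
config.num_layers = 12

model = create_your_model(config)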
Some real-world examples
LLaMA Model Factory
We will use the fairseq2.models.llama.factory.create_llama_model function as an example.
The create_llama_model function serves as a factory method for instantiating a LLaMA model.
It encapsulates the process of building a model with the LLaMABuilder class, which constructs various components of the model based on the provided configuration.
This design pattern allows for a clean separation of model creation logic, making it easier for users to customize and extend the model architecture.
# In llama/factory.py
class LLaMABuilder:
    ...

    def build_model(self) -> TransformerDecoderModel:
        """Build a model."""
        decoder_frontend = self.build_decoder_frontend()

        decoder = self.build_decoder()

        final_proj = Linear(...)

        model = TransformerDecoderModel(decoder_frontend, decoder, final_proj, ...)

        model.set_family(LLAMA_FAMILY)

        return model


def create_llama_model(
    config: LLaMAConfig,
    *,
    device: Device | None = None,
    dtype: DataType | None = None,
) -> TransformerDecoderModel:
    """Create a LLaMA model."""
    return LLaMABuilder(config, device=device, dtype=dtype).build_model()


model_factories.register(LLAMA_FAMILY, create_llama_model, LLaMAConfig, llama_archs)
create_llama_model instantiates the builder class and calls its build_model method, which actually creates the model as a TransformerDecoderModel.
Don’t forget to register your model with the fairseq2 model factories so that it can be easily instantiated later.
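For your own model, that registration boils down to a single call mirroring the LLaMA example above; the names below are the ones used throughout this page, and model_factories is the same global registry shown in the LLaMA snippet:

# In factory.py -- registration sketch; model_factories is the same global
# registry used in the LLaMA example above.
model_factories.register(
    YOUR_MODEL_FAMILY, create_your_model, YourModelConfig, your_model_archs
)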
Create a loader in loader.py that handles model instantiation and checkpoint loading:
from fairseq2.models.config_loader import StandardModelConfigLoader
from fairseq2.models.loader import StandardModelLoader, load_model

# Create config loader
load_your_model_config = StandardModelConfigLoader(
    YOUR_MODEL_FAMILY, YourModelConfig, your_model_archs
)


def convert_your_model_checkpoint(
    checkpoint: dict[str, Any], config: YourModelConfig
) -> dict[str, Any]:
    """Convert external checkpoints to fairseq2 format."""
    # Add checkpoint conversion logic here
    return {"model": checkpoint}


# Create model loader
load_your_model = StandardModelLoader(
    config_loader=load_your_model_config,
    factory=create_your_model,
    checkpoint_converter=convert_your_model_checkpoint,
)

# Register loader with global registry
load_model.register(YOUR_MODEL_FAMILY, load_your_model)
Some real-world examples of checkpoint conversion
The convert_your_model_checkpoint function converts external checkpoints into the format fairseq2 expects.
For example, Mistral's reference checkpoints name their parameters differently from fairseq2, so the converter remaps them:
# In mistral/loader.py
def convert_mistral_checkpoint(
    checkpoint: dict[str, Any], config: MistralConfig
) -> dict[str, Any]:
    """Convert a Mistral checkpoint to fairseq2 format."""
    if "model" in checkpoint:  # Already in fairseq2 format
        return checkpoint

    # Map parameter names from Mistral to fairseq2 format
    key_map = {
        r"^layers\.([0-9]+)\.attention\.wq\.": r"decoder.layers.\1.self_attn.q_proj.",
        r"^layers\.([0-9]+)\.attention\.wk\.": r"decoder.layers.\1.self_attn.k_proj.",
        r"^layers\.([0-9]+)\.attention\.wv\.": r"decoder.layers.\1.self_attn.v_proj.",
        # ... more mappings
    }

    checkpoint = convert_model_state_dict(checkpoint, key_map)

    return {"model": checkpoint}
Overall, to support loading from a different checkpoint format:
- Modify the checkpoint converter function.
- Add mapping logic for the differing parameter names.
- Handle any necessary tensor transformations (see the sketch below).
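The Mistral example above only renames keys. When the tensor layout differs as well, the converter also has to reshape or split tensors before returning them. The snippet below is a purely hypothetical sketch: the fused attention.in_proj layout is invented for illustration and is not any particular model's real format.

# A hypothetical sketch: the source checkpoint stores attention as one fused
# "in_proj" tensor, while the fairseq2 model expects separate q/k/v projections.
from typing import Any

import torch


def convert_fused_checkpoint(
    checkpoint: dict[str, Any], config: YourModelConfig
) -> dict[str, Any]:
    converted: dict[str, Any] = {}

    for key, tensor in checkpoint.items():
        if key.endswith("attention.in_proj.weight"):  # hypothetical fused layout
            prefix = key.removesuffix("attention.in_proj.weight")

            # Split the fused projection into the three tensors fairseq2 expects.
            q, k, v = torch.chunk(tensor, 3, dim=0)

            converted[f"{prefix}self_attn.q_proj.weight"] = q
            converted[f"{prefix}self_attn.k_proj.weight"] = k
            converted[f"{prefix}self_attn.v_proj.weight"] = v
        else:
            converted[key] = tensor

    return {"model": converted}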
Advanced topic: Sharding
The sharder argument in StandardModelLoader is a function that shards the model, which is useful for distributed training.
Sharding itself is natively supported by fairseq2, so you typically only write a thin wrapper rather than implement the logic yourself.
For example, in LLaMA, the shard_llama_model function shards the model across multiple devices:
# In llama/loader.py
from fairseq2.models.transformer import shard_transformer_decoder_model
from fairseq2.models.loader import StandardModelLoader


def shard_llama_model(
    model: TransformerDecoderModel, config: LLaMAConfig, gangs: Mapping[str, Gang]
) -> None:
    gang = gangs["tp"]  # tensor parallel

    shard_embed_dim = config.max_seq_len < 8192  # LLaMA 1 or 2

    shard_transformer_decoder_model(model, gang, shard_embed_dim=shard_embed_dim)


load_llama_model = StandardModelLoader(
    ...
    sharder=shard_llama_model,
)
Finally, once the loader is registered, you can load the model from its asset card and plug it into a training recipe:

from fairseq2.models.loader import load_model
from fairseq2.recipes.trainer import Trainer, TrainUnit
from fairseq2.recipes.utils.asset import retrieve_asset_card

model_card = retrieve_asset_card("llama3_2_1b")

# Load model
model = load_model(model_card, device=Device("cpu"))


# Create training unit
class YourTrainUnit(AbstractTrainUnit[SequenceBatch]):
    def __init__(self, model: YourModel) -> None:
        super().__init__(model)
        self._metric_bag = MetricBag()

    def __call__(self, batch: YourBatchType) -> tuple[Tensor, int]:
        loss = self._model(**batch)
        return loss, batch.num_targets


# Set up trainer
trainer = Trainer(
    unit=YourTrainUnit(model),
    data_reader=your_data_reader,
    optimizer=your_optimizer,
    # ... other trainer parameters
)

# Run training
trainer()