Load Model

Overview

This notebook illustrates how to instantiate models in fairseq2.

[1]:
from fairseq2 import setup_fairseq2

# fairseq2 must be initialized exactly once per process, before any other
# fairseq2 functionality is used.
setup_fairseq2()

All models in fairseq2 inherit from PyTorch’s nn.Module, so they provide the standard PyTorch functionality. Their architecture is described by a configuration object that is easy to customize.

[ ]:
from fairseq2.models.llama import LLaMAConfig, create_llama_model
from fairseq2.data import VocabularyInfo

# Single source of truth for the hidden size; the FFN size below is derived
# from it so the two cannot drift apart while experimenting with the config.
MODEL_DIM = 2048

custom_config = LLaMAConfig(
    model_dim=MODEL_DIM,  # Model (hidden) dimension
    max_seq_len=4096,  # Maximum sequence length (also sizes the rotary encoder)
    vocab_info=VocabularyInfo(
        size=32000,  # Vocabulary size
        unk_idx=0,  # Unknown-token index
        bos_idx=1,  # Beginning-of-sequence index
        eos_idx=2,  # End-of-sequence index
        pad_idx=None,  # No padding token
    ),
    num_layers=16,  # Number of transformer decoder layers
    num_attn_heads=32,  # Number of (query) attention heads
    num_key_value_heads=8,  # Fewer K/V heads than query heads -> grouped-query attention
    # Requested FFN inner dimension. LLaMA uses a gated (GLU) FFN that scales
    # this value by ~2/3 and rounds it up to a multiple of 256, so the printed
    # model below shows an inner size of 5632 rather than 8192.
    ffn_inner_dim=MODEL_DIM * 4,
    dropout_p=0.1,  # Dropout probability (embeddings, attention and FFN)
)

# Build the architecture described by `custom_config`. No checkpoint is
# loaded here, so the weights are randomly initialized.
model = create_llama_model(custom_config)
model
TransformerDecoderModel(
  model_dim=2048
  (decoder_frontend): TransformerEmbeddingFrontend(
    model_dim=2048
    (embed): StandardEmbedding(num_embeddings=32000, embedding_dim=2048, init_fn=init_embed)
    (pos_encoder): None
    (layer_norm): None
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (decoder): StandardTransformerDecoder(
    model_dim=2048, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE
    (layers): ModuleList(
      (0-15): 16 x StandardTransformerDecoderLayer(
        model_dim=2048, norm_order=PRE
        (self_attn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
        (self_attn): StandardMultiheadAttention(
          num_heads=32, model_dim=2048, num_key_value_heads=8
          (q_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
          (k_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
          (v_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
          (pos_encoder): RotaryEncoder(encoding_dim=64, max_seq_len=4096)
          (sdpa): TorchSDPA(attn_dropout_p=0.1)
          (output_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
        )
        (self_attn_norm): None
        (self_attn_dropout): None
        (self_attn_residual): StandardResidualConnect()
        (encoder_decoder_attn): None
        (encoder_decoder_attn_dropout): None
        (encoder_decoder_attn_residual): None
        (encoder_decoder_attn_layer_norm): None
        (ffn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
        (ffn): GLUFeedForwardNetwork(
          model_dim=2048, inner_dim_scale=0.666667, inner_dim_to_multiple=256
          (gate_proj): Linear(input_dim=2048, output_dim=5632, bias=False, init_fn=init_projection)
          (gate_activation): SiLU()
          (inner_proj): Linear(input_dim=2048, output_dim=5632, bias=False, init_fn=init_projection)
          (inner_dropout): Dropout(p=0.1, inplace=False)
          (output_proj): Linear(input_dim=5632, output_dim=2048, bias=False, init_fn=init_projection)
        )
        (ffn_dropout): None
        (ffn_residual): StandardResidualConnect()
      )
    )
    (layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (final_proj): Linear(input_dim=2048, output_dim=32000, bias=False, init_fn=init_projection)
)
[3]:
# A freshly created model lives on the CPU in PyTorch's default dtype;
# inspecting one parameter is enough to see where the whole model sits.
first_param = next(model.parameters())
print(f"Initial device: {first_param.device}")
print(f"Initial dtype: {first_param.dtype}")
Initial device: cpu
Initial dtype: torch.float32
[5]:
import torch

# Choose a placement suited to this machine:
# - use the GPU when one is visible, otherwise stay on the CPU;
# - use bfloat16 only on GPUs with native bfloat16 support (compute
#   capability >= 8.0, i.e. Ampere such as A100 and Hopper such as H100).
#   Elsewhere keep float32: bfloat16 on CPUs and older GPUs may be emulated
#   and slow, so it only costs precision.
if torch.cuda.is_available():
    device = torch.device("cuda")
    dtype = (
        torch.bfloat16
        if torch.cuda.get_device_capability(device)[0] >= 8
        else torch.float32
    )
else:
    device = torch.device("cpu")
    dtype = torch.float32

# nn.Module.to casts/moves the parameters and returns the module itself.
model = model.to(device=device, dtype=dtype)

# Verify the change
print("After moving:")
print(f"Device: {next(model.parameters()).device}")
print(f"Dtype: {next(model.parameters()).dtype}")
After moving:
Device: cuda:0
Dtype: torch.bfloat16

Create Model from Hub

Using Model Config

You can fetch the registered configurations that are available in the model hub.

[8]:
from fairseq2.models.llama import get_llama_model_hub, create_llama_model

# Architecture presets are looked up by name in the LLaMA model hub; the
# Llama 3.1 8B Instruct preset serves as the example here.
model_hub = get_llama_model_hub()
model_config = model_hub.load_config("llama3_1_8b_instruct")

# Only the configuration comes from the hub; create_llama_model builds the
# model with random weights.
# NOTE(review): at ~8B parameters in the default float32 this needs tens of GB
# of host memory — make sure the machine has enough RAM before running.
llama_model = create_llama_model(model_config)
llama_model
[8]:
TransformerDecoderModel(
  model_dim=4096
  (decoder_frontend): TransformerEmbeddingFrontend(
    model_dim=4096
    (embed): StandardEmbedding(num_embeddings=128256, embedding_dim=4096)
    (pos_encoder): None
    (layer_norm): None
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (decoder): StandardTransformerDecoder(
    model_dim=4096, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE
    (layers): ModuleList(
      (0-31): 32 x StandardTransformerDecoderLayer(
        model_dim=4096, norm_order=PRE
        (self_attn_layer_norm): RMSNorm(normalized_shape=(4096,), eps=1E-05, elementwise_affine=True)
        (self_attn): StandardMultiheadAttention(
          num_heads=32, model_dim=4096, num_key_value_heads=8
          (q_proj): Linear(input_dim=4096, output_dim=4096, bias=False, init_fn=init_qkv_projection)
          (k_proj): Linear(input_dim=4096, output_dim=1024, bias=False, init_fn=init_qkv_projection)
          (v_proj): Linear(input_dim=4096, output_dim=1024, bias=False, init_fn=init_qkv_projection)
          (pos_encoder): RotaryEncoder(encoding_dim=128, max_seq_len=131072)
          (sdpa): TorchSDPA(attn_dropout_p=0.1)
          (output_proj): Linear(input_dim=4096, output_dim=4096, bias=False, init_fn=init_output_projection)
        )
        (self_attn_norm): None
        (self_attn_dropout): None
        (self_attn_residual): StandardResidualConnect()
        (encoder_decoder_attn): None
        (encoder_decoder_attn_dropout): None
        (encoder_decoder_attn_residual): None
        (encoder_decoder_attn_layer_norm): None
        (ffn_layer_norm): RMSNorm(normalized_shape=(4096,), eps=1E-05, elementwise_affine=True)
        (ffn): GLUFeedForwardNetwork(
          model_dim=4096, inner_dim_scale=0.666667, inner_dim_to_multiple=1024
          (gate_proj): Linear(input_dim=4096, output_dim=14336, bias=False)
          (gate_activation): SiLU()
          (inner_proj): Linear(input_dim=4096, output_dim=14336, bias=False)
          (inner_dropout): Dropout(p=0.1, inplace=False)
          (output_proj): Linear(input_dim=14336, output_dim=4096, bias=False)
        )
        (ffn_dropout): None
        (ffn_residual): StandardResidualConnect()
      )
    )
    (layer_norm): RMSNorm(normalized_shape=(4096,), eps=1E-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (final_proj): Linear(input_dim=4096, output_dim=128256, bias=False, init_fn=init_final_projection)
)

Directly Using Registered Model Name

To see which models are registered, we can use the asset_store of the runtime context, which provides a centralized way to access global resources and services throughout the codebase.

The asset_store is a key component that manages model assets and their configurations.

The runtime context is particularly important for fairseq2’s extensibility:

  1. It allows for registering custom models, configs, assets etc.

  2. It provides a unified interface for accessing these resources

  3. It can be customized to support different backends or storage systems

[7]:
from fairseq2.context import get_runtime_context

# The runtime context gives access to fairseq2's global resources. Its
# asset store resolves the asset names (e.g. model cards) used below.
context = get_runtime_context()
asset_store = context.asset_store
[10]:
# Keep only the registered names that refer to Llama 3.1 assets; each name
# carries an "@..." suffix, so one model can appear several times.
list(filter(lambda name: "llama3_1" in name, asset_store.retrieve_names()))
[10]:
['llama3_1_8b@',
 'llama3_1_8b_instruct@',
 'llama3_1_70b@',
 'llama3_1_70b_instruct@',
 'llama3_1_8b@cluster0',
 'llama3_1_8b@cluster3',
 'llama3_1_8b_instruct@cluster1',
 'llama3_1_8b_instruct@cluster0',
 'llama3_1_8b_instruct@cluster3',
 'llama3_1_70b@cluster0',
 'llama3_1_70b@cluster3',
 'llama3_1_70b_instruct@cluster1',
 'llama3_1_70b_instruct@cluster0',
 'llama3_1_70b_instruct@cluster3',
 'llama3_1_8b@cluster2',
 'llama3_1_8b@cluster4',
 'llama3_1_8b_instruct@cluster2',
 'llama3_1_8b_instruct@cluster4',
 'llama3_1_70b@cluster2',
 'llama3_1_70b@cluster4',
 'llama3_1_70b_instruct@cluster2',
 'llama3_1_70b_instruct@cluster4']

A pretrained model can also be loaded directly from the hub.

[11]:
from fairseq2.models.llama import get_llama_model_hub

# Unlike create_llama_model, `load` resolves a registered asset card and
# returns the pretrained model. "llama3_2_1b" must be the name of a
# registered asset card.
model_hub = get_llama_model_hub()
model = model_hub.load("llama3_2_1b")
model
[11]:
TransformerDecoderModel(
  model_dim=2048
  (decoder_frontend): TransformerEmbeddingFrontend(
    model_dim=2048
    (embed): StandardEmbedding(num_embeddings=128256, embedding_dim=2048, init_fn=init_embed)
    (pos_encoder): None
    (layer_norm): None
    (dropout): None
  )
  (decoder): StandardTransformerDecoder(
    model_dim=2048, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE
    (layers): ModuleList(
      (0-15): 16 x StandardTransformerDecoderLayer(
        model_dim=2048, norm_order=PRE
        (self_attn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
        (self_attn): StandardMultiheadAttention(
          num_heads=32, model_dim=2048, num_key_value_heads=8
          (q_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
          (k_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
          (v_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
          (pos_encoder): RotaryEncoder(encoding_dim=64, max_seq_len=131072)
          (sdpa): TorchSDPA(attn_dropout_p=0)
          (output_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
        )
        (self_attn_norm): None
        (self_attn_dropout): None
        (self_attn_residual): StandardResidualConnect()
        (encoder_decoder_attn): None
        (encoder_decoder_attn_dropout): None
        (encoder_decoder_attn_residual): None
        (encoder_decoder_attn_layer_norm): None
        (ffn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
        (ffn): GLUFeedForwardNetwork(
          model_dim=2048, inner_dim_scale=0.666667, inner_dim_to_multiple=256
          (gate_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)
          (gate_activation): SiLU()
          (inner_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)
          (inner_dropout): None
          (output_proj): Linear(input_dim=8192, output_dim=2048, bias=False, init_fn=init_projection)
        )
        (ffn_dropout): None
        (ffn_residual): StandardResidualConnect()
      )
    )
    (layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
    (dropout): None
  )
  (final_proj): TiedProjection(input_dim=2048, output_dim=128256)
)

Using Model Card

We can also load a model directly from its model card.

[18]:
# Retrieve the asset card itself; its repr shows the card's metadata, such as
# the model architecture and the checkpoint location.
# NOTE(review): that repr includes local filesystem paths — review the rendered
# output before sharing the notebook.
model_card = asset_store.retrieve_card("llama3_2_1b")
model_card
[18]:
{'base': 'llama3', 'model_arch': 'llama3_2_1b', '__base_path__': PosixPath('/fsx-checkpoints/yaoj/envs/fs2_nightly_pt25_cu121/conda/lib/python3.10/site-packages/fairseq2_ext/cards/models'), '__source__': 'package:fairseq2_ext.cards', 'checkpoint': '/fsx-ram/shared/Llama-3.2-1B/original/consolidated.00.pth', 'name': 'llama3_2_1b'}
[ ]:
# `load` accepts the card object as well as the card's name. Both `model` and
# `llama_model` are rebound to the result, which releases their references to
# the models created in earlier cells.
model = model_hub.load(model_card)
llama_model = model
llama_model
TransformerDecoderModel(
  model_dim=2048
  (decoder_frontend): TransformerEmbeddingFrontend(
    model_dim=2048
    (embed): StandardEmbedding(num_embeddings=128256, embedding_dim=2048, init_fn=init_embed)
    (pos_encoder): None
    (layer_norm): None
    (dropout): None
  )
  (decoder): StandardTransformerDecoder(
    model_dim=2048, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE
    (layers): ModuleList(
      (0-15): 16 x StandardTransformerDecoderLayer(
        model_dim=2048, norm_order=PRE
        (self_attn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
        (self_attn): StandardMultiheadAttention(
          num_heads=32, model_dim=2048, num_key_value_heads=8
          (q_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
          (k_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
          (v_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
          (pos_encoder): RotaryEncoder(encoding_dim=64, max_seq_len=131072)
          (sdpa): TorchSDPA(attn_dropout_p=0)
          (output_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
        )
        (self_attn_norm): None
        (self_attn_dropout): None
        (self_attn_residual): StandardResidualConnect()
        (encoder_decoder_attn): None
        (encoder_decoder_attn_dropout): None
        (encoder_decoder_attn_residual): None
        (encoder_decoder_attn_layer_norm): None
        (ffn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
        (ffn): GLUFeedForwardNetwork(
          model_dim=2048, inner_dim_scale=0.666667, inner_dim_to_multiple=256
          (gate_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)
          (gate_activation): SiLU()
          (inner_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)
          (inner_dropout): None
          (output_proj): Linear(input_dim=8192, output_dim=2048, bias=False, init_fn=init_projection)
        )
        (ffn_dropout): None
        (ffn_residual): StandardResidualConnect()
      )
    )
    (layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
    (dropout): None
  )
  (final_proj): TiedProjection(input_dim=2048, output_dim=128256)
)