✎ Load Model¶
Overview¶
This notebook illustrates how to instantiate models in fairseq2.
[1]:
from fairseq2 import setup_fairseq2
# Always call setup_fairseq2() before using any fairseq2 functionality
setup_fairseq2()
All models in fairseq2 inherit from PyTorch’s nn.Module
, so they provide standard PyTorch functionality. Their configuration can be easily customized.
[ ]:
from fairseq2.models.llama import LLaMAConfig, create_llama_model
from fairseq2.data import VocabularyInfo
custom_config = LLaMAConfig(
    model_dim=2048,          # Model dimension
    max_seq_len=4096,        # Maximum sequence length
    vocab_info=VocabularyInfo(
        size=32000,          # Vocabulary size
        unk_idx=0,           # Unknown index
        bos_idx=1,           # Beginning of sequence index
        eos_idx=2,           # End of sequence index
        pad_idx=None,        # Padding index
    ),
    num_layers=16,           # Number of transformer layers
    num_attn_heads=32,       # Number of attention heads
    num_key_value_heads=8,   # Number of key/value heads
    ffn_inner_dim=2048 * 4,  # FFN inner dimension
    dropout_p=0.1,           # Dropout probability
)
# this will initialize a model with random weights
model = create_llama_model(custom_config)
model
TransformerDecoderModel(
model_dim=2048
(decoder_frontend): TransformerEmbeddingFrontend(
model_dim=2048
(embed): StandardEmbedding(num_embeddings=32000, embedding_dim=2048, init_fn=init_embed)
(pos_encoder): None
(layer_norm): None
(dropout): Dropout(p=0.1, inplace=False)
)
(decoder): StandardTransformerDecoder(
model_dim=2048, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE
(layers): ModuleList(
(0-15): 16 x StandardTransformerDecoderLayer(
model_dim=2048, norm_order=PRE
(self_attn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
(self_attn): StandardMultiheadAttention(
num_heads=32, model_dim=2048, num_key_value_heads=8
(q_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
(k_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
(v_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
(pos_encoder): RotaryEncoder(encoding_dim=64, max_seq_len=4096)
(sdpa): TorchSDPA(attn_dropout_p=0.1)
(output_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
)
(self_attn_norm): None
(self_attn_dropout): None
(self_attn_residual): StandardResidualConnect()
(encoder_decoder_attn): None
(encoder_decoder_attn_dropout): None
(encoder_decoder_attn_residual): None
(encoder_decoder_attn_layer_norm): None
(ffn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
(ffn): GLUFeedForwardNetwork(
model_dim=2048, inner_dim_scale=0.666667, inner_dim_to_multiple=256
(gate_proj): Linear(input_dim=2048, output_dim=5632, bias=False, init_fn=init_projection)
(gate_activation): SiLU()
(inner_proj): Linear(input_dim=2048, output_dim=5632, bias=False, init_fn=init_projection)
(inner_dropout): Dropout(p=0.1, inplace=False)
(output_proj): Linear(input_dim=5632, output_dim=2048, bias=False, init_fn=init_projection)
)
(ffn_dropout): None
(ffn_residual): StandardResidualConnect()
)
)
(layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
(dropout): Dropout(p=0.1, inplace=False)
)
(final_proj): Linear(input_dim=2048, output_dim=32000, bias=False, init_fn=init_projection)
)
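Because the model is a regular nn.Module, standard PyTorch introspection works out of the box. Below is a minimal sketch (plain PyTorch only, no fairseq2-specific APIs) that counts the model’s parameters:
[ ]:
# Count the trainable parameters of the freshly created model
# using only standard PyTorch APIs.
num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params:,}")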
[3]:
# the model is initialized on CPU with default dtype
print(f"Initial device: {next(model.parameters()).device}")
print(f"Initial dtype: {next(model.parameters()).dtype}")
Initial device: cpu
Initial dtype: torch.float32
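If you prefer not to materialize the random weights on CPU at all, PyTorch’s torch.device context manager can be tried around the factory call. This is only a sketch: whether it takes effect depends on how create_llama_model allocates its parameters, so check the resulting device before relying on it.
[ ]:
import torch

# Sketch: create the model under the "meta" device so parameters are
# allocated without real storage. Whether this is honored depends on how
# create_llama_model allocates parameters; verify the device afterwards.
with torch.device("meta"):
    meta_model = create_llama_model(custom_config)

print(next(meta_model.parameters()).device)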
[5]:
import torch
# you can also move the model to GPU with bfloat16 dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 # Modern GPUs (e.g. H100) perform well with bfloat16
model = model.to(device=device, dtype=dtype)
# Verify the change
print("After moving:")
print(f"Device: {next(model.parameters()).device}")
print(f"Dtype: {next(model.parameters()).dtype}")
After moving:
Device: cuda:0
Dtype: torch.bfloat16
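Casting to bfloat16 halves the parameter memory relative to float32. The following sketch (plain PyTorch, no fairseq2-specific APIs) estimates the footprint after the cast:
[ ]:
# element_size() is 2 bytes per parameter for bfloat16.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"Parameter memory: {param_bytes / 1024**3:.2f} GiB")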
Create Model from Hub¶
Using Model Config¶
You can fetch registered model configs that are available in the model hub.
[8]:
from fairseq2.models.llama import get_llama_model_hub, create_llama_model
model_hub = get_llama_model_hub()
model_config = model_hub.load_config(
    "llama3_1_8b_instruct"
)  # use the llama3_1_8b_instruct preset as an example
llama_model = create_llama_model(model_config)
llama_model
[8]:
TransformerDecoderModel(
model_dim=4096
(decoder_frontend): TransformerEmbeddingFrontend(
model_dim=4096
(embed): StandardEmbedding(num_embeddings=128256, embedding_dim=4096)
(pos_encoder): None
(layer_norm): None
(dropout): Dropout(p=0.1, inplace=False)
)
(decoder): StandardTransformerDecoder(
model_dim=4096, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE
(layers): ModuleList(
(0-31): 32 x StandardTransformerDecoderLayer(
model_dim=4096, norm_order=PRE
(self_attn_layer_norm): RMSNorm(normalized_shape=(4096,), eps=1E-05, elementwise_affine=True)
(self_attn): StandardMultiheadAttention(
num_heads=32, model_dim=4096, num_key_value_heads=8
(q_proj): Linear(input_dim=4096, output_dim=4096, bias=False, init_fn=init_qkv_projection)
(k_proj): Linear(input_dim=4096, output_dim=1024, bias=False, init_fn=init_qkv_projection)
(v_proj): Linear(input_dim=4096, output_dim=1024, bias=False, init_fn=init_qkv_projection)
(pos_encoder): RotaryEncoder(encoding_dim=128, max_seq_len=131072)
(sdpa): TorchSDPA(attn_dropout_p=0.1)
(output_proj): Linear(input_dim=4096, output_dim=4096, bias=False, init_fn=init_output_projection)
)
(self_attn_norm): None
(self_attn_dropout): None
(self_attn_residual): StandardResidualConnect()
(encoder_decoder_attn): None
(encoder_decoder_attn_dropout): None
(encoder_decoder_attn_residual): None
(encoder_decoder_attn_layer_norm): None
(ffn_layer_norm): RMSNorm(normalized_shape=(4096,), eps=1E-05, elementwise_affine=True)
(ffn): GLUFeedForwardNetwork(
model_dim=4096, inner_dim_scale=0.666667, inner_dim_to_multiple=1024
(gate_proj): Linear(input_dim=4096, output_dim=14336, bias=False)
(gate_activation): SiLU()
(inner_proj): Linear(input_dim=4096, output_dim=14336, bias=False)
(inner_dropout): Dropout(p=0.1, inplace=False)
(output_proj): Linear(input_dim=14336, output_dim=4096, bias=False)
)
(ffn_dropout): None
(ffn_residual): StandardResidualConnect()
)
)
(layer_norm): RMSNorm(normalized_shape=(4096,), eps=1E-05, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(final_proj): Linear(input_dim=4096, output_dim=128256, bias=False, init_fn=init_final_projection)
)
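A loaded config can also be adjusted before instantiating the model. The sketch below assumes LLaMAConfig is a standard Python dataclass (so dataclasses.replace applies); the field names used here (num_layers, dropout_p) are the ones shown in the custom config above.
[ ]:
from dataclasses import replace

# Build a smaller variant of the preset for quick experiments.
# Assumption: LLaMAConfig is a dataclass, so dataclasses.replace works.
small_config = replace(model_config, num_layers=8, dropout_p=0.0)

small_model = create_llama_model(small_config)
print(sum(p.numel() for p in small_model.parameters()))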
Directly Using a Registered Model Name¶
To check which models are registered, we can leverage the asset_store
in our runtime context. The runtime context provides a centralized way to access global resources and services throughout the codebase, and the asset_store
is the key component that manages model assets and their configurations.
The runtime context is particularly important for fairseq2’s extensibility:
- It allows registering custom models, configs, assets, etc.
- It provides a unified interface for accessing these resources.
- It can be customized to support different backends or storage systems.
[7]:
from fairseq2.context import get_runtime_context
context = get_runtime_context()
asset_store = context.asset_store
[10]:
[asset for asset in asset_store.retrieve_names() if "llama3_1" in asset]
[10]:
['llama3_1_8b@',
'llama3_1_8b_instruct@',
'llama3_1_70b@',
'llama3_1_70b_instruct@',
'llama3_1_8b@cluster0',
'llama3_1_8b@cluster3',
'llama3_1_8b_instruct@cluster1',
'llama3_1_8b_instruct@cluster0',
'llama3_1_8b_instruct@cluster3',
'llama3_1_70b@cluster0',
'llama3_1_70b@cluster3',
'llama3_1_70b_instruct@cluster1',
'llama3_1_70b_instruct@cluster0',
'llama3_1_70b_instruct@cluster3',
'llama3_1_8b@cluster2',
'llama3_1_8b@cluster4',
'llama3_1_8b_instruct@cluster2',
'llama3_1_8b_instruct@cluster4',
'llama3_1_70b@cluster2',
'llama3_1_70b@cluster4',
'llama3_1_70b_instruct@cluster2',
'llama3_1_70b_instruct@cluster4']
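The part after @ appears to identify an environment-specific variant of the same card (an empty suffix being the default). If you only care about the base names, a small sketch:
[ ]:
# Strip the "@<env>" suffix (interpreted here as an environment-specific
# variant marker) and keep one entry per base name.
names = [asset for asset in asset_store.retrieve_names() if "llama3_1" in asset]
sorted({name.split("@", 1)[0] for name in names})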
A pretrained model can also be loaded directly from the hub.
[11]:
from fairseq2.models.llama import get_llama_model_hub
model_hub = get_llama_model_hub()
# Load a pre-trained model from the hub
model = model_hub.load(
    "llama3_2_1b"
)  # "llama3_2_1b" must be the name of a registered asset card
model
[11]:
TransformerDecoderModel(
model_dim=2048
(decoder_frontend): TransformerEmbeddingFrontend(
model_dim=2048
(embed): StandardEmbedding(num_embeddings=128256, embedding_dim=2048, init_fn=init_embed)
(pos_encoder): None
(layer_norm): None
(dropout): None
)
(decoder): StandardTransformerDecoder(
model_dim=2048, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE
(layers): ModuleList(
(0-15): 16 x StandardTransformerDecoderLayer(
model_dim=2048, norm_order=PRE
(self_attn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
(self_attn): StandardMultiheadAttention(
num_heads=32, model_dim=2048, num_key_value_heads=8
(q_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
(k_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
(v_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
(pos_encoder): RotaryEncoder(encoding_dim=64, max_seq_len=131072)
(sdpa): TorchSDPA(attn_dropout_p=0)
(output_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
)
(self_attn_norm): None
(self_attn_dropout): None
(self_attn_residual): StandardResidualConnect()
(encoder_decoder_attn): None
(encoder_decoder_attn_dropout): None
(encoder_decoder_attn_residual): None
(encoder_decoder_attn_layer_norm): None
(ffn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
(ffn): GLUFeedForwardNetwork(
model_dim=2048, inner_dim_scale=0.666667, inner_dim_to_multiple=256
(gate_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)
(gate_activation): SiLU()
(inner_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)
(inner_dropout): None
(output_proj): Linear(input_dim=8192, output_dim=2048, bias=False, init_fn=init_projection)
)
(ffn_dropout): None
(ffn_residual): StandardResidualConnect()
)
)
(layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
(dropout): None
)
(final_proj): TiedProjection(input_dim=2048, output_dim=128256)
)
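The repr above ends with a TiedProjection, which suggests the output projection shares its weight with the token embedding. One way to check for shared parameters with plain PyTorch (a sketch; it only detects weights registered by more than one module):
[ ]:
# Compare parameter counts with and without duplicate removal; a difference
# means at least one weight is registered by more than one module
# (here, presumably the tied embedding / output projection).
unique = sum(p.numel() for _, p in model.named_parameters())
total = sum(p.numel() for _, p in model.named_parameters(remove_duplicate=False))
print(f"unique={unique:,}  with_duplicates={total:,}  shared={total - unique:,}")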
Using a Model Card¶
We can also load a model directly from its model card.
[18]:
model_card = asset_store.retrieve_card("llama3_2_1b")
model_card
[18]:
{'base': 'llama3', 'model_arch': 'llama3_2_1b', '__base_path__': PosixPath('/fsx-checkpoints/yaoj/envs/fs2_nightly_pt25_cu121/conda/lib/python3.10/site-packages/fairseq2_ext/cards/models'), '__source__': 'package:fairseq2_ext.cards', 'checkpoint': '/fsx-ram/shared/Llama-3.2-1B/original/consolidated.00.pth', 'name': 'llama3_2_1b'}
[ ]:
llama_model = model_hub.load(model_card)
llama_model
TransformerDecoderModel(
model_dim=2048
(decoder_frontend): TransformerEmbeddingFrontend(
model_dim=2048
(embed): StandardEmbedding(num_embeddings=128256, embedding_dim=2048, init_fn=init_embed)
(pos_encoder): None
(layer_norm): None
(dropout): None
)
(decoder): StandardTransformerDecoder(
model_dim=2048, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE
(layers): ModuleList(
(0-15): 16 x StandardTransformerDecoderLayer(
model_dim=2048, norm_order=PRE
(self_attn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
(self_attn): StandardMultiheadAttention(
num_heads=32, model_dim=2048, num_key_value_heads=8
(q_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
(k_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
(v_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)
(pos_encoder): RotaryEncoder(encoding_dim=64, max_seq_len=131072)
(sdpa): TorchSDPA(attn_dropout_p=0)
(output_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)
)
(self_attn_norm): None
(self_attn_dropout): None
(self_attn_residual): StandardResidualConnect()
(encoder_decoder_attn): None
(encoder_decoder_attn_dropout): None
(encoder_decoder_attn_residual): None
(encoder_decoder_attn_layer_norm): None
(ffn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
(ffn): GLUFeedForwardNetwork(
model_dim=2048, inner_dim_scale=0.666667, inner_dim_to_multiple=256
(gate_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)
(gate_activation): SiLU()
(inner_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)
(inner_dropout): None
(output_proj): Linear(input_dim=8192, output_dim=2048, bias=False, init_fn=init_projection)
)
(ffn_dropout): None
(ffn_residual): StandardResidualConnect()
)
)
(layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)
(dropout): None
)
(final_proj): TiedProjection(input_dim=2048, output_dim=128256)
)