{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ✎ Load Model\n", "\n", "## Overview\n", "\n", "This notebook aims at illustrating on how to instantiate models in fairseq2." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from fairseq2 import setup_fairseq2\n", "\n", "# Always call setup_fairseq2() before using any fairseq2 functionality\n", "setup_fairseq2()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All models in fairseq2 inherit from PyTorch's `nn.Module`, providing standard PyTorch funtionality. The configuration can be easily customized." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TransformerDecoderModel(\n", " model_dim=2048\n", " (decoder_frontend): TransformerEmbeddingFrontend(\n", " model_dim=2048\n", " (embed): StandardEmbedding(num_embeddings=32000, embedding_dim=2048, init_fn=init_embed)\n", " (pos_encoder): None\n", " (layer_norm): None\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (decoder): StandardTransformerDecoder(\n", " model_dim=2048, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE\n", " (layers): ModuleList(\n", " (0-15): 16 x StandardTransformerDecoderLayer(\n", " model_dim=2048, norm_order=PRE\n", " (self_attn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)\n", " (self_attn): StandardMultiheadAttention(\n", " num_heads=32, model_dim=2048, num_key_value_heads=8\n", " (q_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)\n", " (k_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)\n", " (v_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)\n", " (pos_encoder): RotaryEncoder(encoding_dim=64, max_seq_len=4096)\n", " (sdpa): TorchSDPA(attn_dropout_p=0.1)\n", " (output_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)\n", " )\n", " (self_attn_norm): None\n", " (self_attn_dropout): None\n", " (self_attn_residual): StandardResidualConnect()\n", " (encoder_decoder_attn): None\n", " (encoder_decoder_attn_dropout): None\n", " (encoder_decoder_attn_residual): None\n", " (encoder_decoder_attn_layer_norm): None\n", " (ffn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)\n", " (ffn): GLUFeedForwardNetwork(\n", " model_dim=2048, inner_dim_scale=0.666667, inner_dim_to_multiple=256\n", " (gate_proj): Linear(input_dim=2048, output_dim=5632, bias=False, init_fn=init_projection)\n", " (gate_activation): SiLU()\n", " (inner_proj): Linear(input_dim=2048, output_dim=5632, bias=False, init_fn=init_projection)\n", " (inner_dropout): Dropout(p=0.1, inplace=False)\n", " (output_proj): Linear(input_dim=5632, output_dim=2048, bias=False, init_fn=init_projection)\n", " )\n", " (ffn_dropout): None\n", " (ffn_residual): StandardResidualConnect()\n", " )\n", " )\n", " (layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (final_proj): Linear(input_dim=2048, output_dim=32000, bias=False, init_fn=init_projection)\n", ")" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from fairseq2.models.llama import LLaMAConfig, create_llama_model\n", "from fairseq2.data import VocabularyInfo\n", "\n", "custom_config = LLaMAConfig(\n", " model_dim=2048, # Model 
dimension\n", " max_seq_len=4096, # Maximum sequence length\n", " vocab_info=VocabularyInfo(\n", " size=32000, # Vocabulary size\n", " unk_idx=0, # Unknown index\n", " bos_idx=1, # Beginning of sequence index\n", " eos_idx=2, # End of sequence index\n", " pad_idx=None, # Padding index\n", " ),\n", " num_layers=16, # Number of transformer layers\n", " num_attn_heads=32, # Number of attention heads\n", " num_key_value_heads=8, # Number of key/value heads\n", " ffn_inner_dim=2048 * 4, # FFN inner dimension\n", " dropout_p=0.1, # Dropout probability\n", ")\n", "\n", "# this will initialize a model with random weights\n", "model = create_llama_model(custom_config)\n", "model" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial device: cpu\n", "Initial dtype: torch.float32\n" ] } ], "source": [ "# the model is initialized on CPU with default dtype\n", "print(f\"Initial device: {next(model.parameters()).device}\")\n", "print(f\"Initial dtype: {next(model.parameters()).dtype}\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "After moving:\n", "Device: cuda:0\n", "Dtype: torch.bfloat16\n" ] } ], "source": [ "import torch\n", "\n", "# you can also move the model to GPU with bfloat16 dtype\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "dtype = torch.bfloat16 # Modern GPUs (e.g. H100) perform well with bfloat16\n", "\n", "model = model.to(device=device, dtype=dtype)\n", "\n", "# Verify the change\n", "print(\"After moving:\")\n", "print(f\"Device: {next(model.parameters()).device}\")\n", "print(f\"Dtype: {next(model.parameters()).dtype}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Model from Hub" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using Model Config" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can fetch some registered configs available in model hub." 
] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TransformerDecoderModel(\n", " model_dim=4096\n", " (decoder_frontend): TransformerEmbeddingFrontend(\n", " model_dim=4096\n", " (embed): StandardEmbedding(num_embeddings=128256, embedding_dim=4096)\n", " (pos_encoder): None\n", " (layer_norm): None\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (decoder): StandardTransformerDecoder(\n", " model_dim=4096, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE\n", " (layers): ModuleList(\n", " (0-31): 32 x StandardTransformerDecoderLayer(\n", " model_dim=4096, norm_order=PRE\n", " (self_attn_layer_norm): RMSNorm(normalized_shape=(4096,), eps=1E-05, elementwise_affine=True)\n", " (self_attn): StandardMultiheadAttention(\n", " num_heads=32, model_dim=4096, num_key_value_heads=8\n", " (q_proj): Linear(input_dim=4096, output_dim=4096, bias=False, init_fn=init_qkv_projection)\n", " (k_proj): Linear(input_dim=4096, output_dim=1024, bias=False, init_fn=init_qkv_projection)\n", " (v_proj): Linear(input_dim=4096, output_dim=1024, bias=False, init_fn=init_qkv_projection)\n", " (pos_encoder): RotaryEncoder(encoding_dim=128, max_seq_len=131072)\n", " (sdpa): TorchSDPA(attn_dropout_p=0.1)\n", " (output_proj): Linear(input_dim=4096, output_dim=4096, bias=False, init_fn=init_output_projection)\n", " )\n", " (self_attn_norm): None\n", " (self_attn_dropout): None\n", " (self_attn_residual): StandardResidualConnect()\n", " (encoder_decoder_attn): None\n", " (encoder_decoder_attn_dropout): None\n", " (encoder_decoder_attn_residual): None\n", " (encoder_decoder_attn_layer_norm): None\n", " (ffn_layer_norm): RMSNorm(normalized_shape=(4096,), eps=1E-05, elementwise_affine=True)\n", " (ffn): GLUFeedForwardNetwork(\n", " model_dim=4096, inner_dim_scale=0.666667, inner_dim_to_multiple=1024\n", " (gate_proj): Linear(input_dim=4096, output_dim=14336, bias=False)\n", " (gate_activation): SiLU()\n", " (inner_proj): Linear(input_dim=4096, output_dim=14336, bias=False)\n", " (inner_dropout): Dropout(p=0.1, inplace=False)\n", " (output_proj): Linear(input_dim=14336, output_dim=4096, bias=False)\n", " )\n", " (ffn_dropout): None\n", " (ffn_residual): StandardResidualConnect()\n", " )\n", " )\n", " (layer_norm): RMSNorm(normalized_shape=(4096,), eps=1E-05, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (final_proj): Linear(input_dim=4096, output_dim=128256, bias=False, init_fn=init_final_projection)\n", ")" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from fairseq2.models.llama import get_llama_model_hub, create_llama_model\n", "\n", "model_hub = get_llama_model_hub()\n", "model_config = model_hub.load_config(\n", " \"llama3_1_8b_instruct\"\n", ") # use llama3.1 8b preset as an example\n", "\n", "llama_model = create_llama_model(model_config)\n", "llama_model" ] },
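{ "cell_type": "markdown", "metadata": {}, "source": [ "The returned config is a regular Python config object, so its fields can be adjusted before the model is built. A minimal sketch (assuming the fields printed above, such as `dropout_p`, can simply be reassigned):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative only: tweak the preset config before instantiating the model\n", "model_config.dropout_p = 0.0 # e.g. disable dropout for inference\n", "\n", "llama_model = create_llama_model(model_config)" ] },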
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Directly Using Registered Model Name\n", "\n", "To check which models are registered, we can leverage the `asset_store` in our **runtime context**, which provides a centralized way to access global resources and services throughout the codebase.\n", "\n", "The `asset_store` is a key component that manages model assets and their configurations.\n", "\n", "The runtime context is particularly important for fairseq2's extensibility:\n", "1. It allows for registering custom models, configs, assets, etc.\n", "2. It provides a unified interface for accessing these resources\n", "3. It can be customized to support different backends or storage systems" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from fairseq2.context import get_runtime_context\n", "\n", "context = get_runtime_context()\n", "asset_store = context.asset_store" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['llama3_1_8b@',\n", " 'llama3_1_8b_instruct@',\n", " 'llama3_1_70b@',\n", " 'llama3_1_70b_instruct@',\n", " 'llama3_1_8b@cluster0',\n", " 'llama3_1_8b@cluster3',\n", " 'llama3_1_8b_instruct@cluster1',\n", " 'llama3_1_8b_instruct@cluster0',\n", " 'llama3_1_8b_instruct@cluster3',\n", " 'llama3_1_70b@cluster0',\n", " 'llama3_1_70b@cluster3',\n", " 'llama3_1_70b_instruct@cluster1',\n", " 'llama3_1_70b_instruct@cluster0',\n", " 'llama3_1_70b_instruct@cluster3',\n", " 'llama3_1_8b@cluster2',\n", " 'llama3_1_8b@cluster4',\n", " 'llama3_1_8b_instruct@cluster2',\n", " 'llama3_1_8b_instruct@cluster4',\n", " 'llama3_1_70b@cluster2',\n", " 'llama3_1_70b@cluster4',\n", " 'llama3_1_70b_instruct@cluster2',\n", " 'llama3_1_70b_instruct@cluster4']" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[asset for asset in asset_store.retrieve_names() if \"llama3_1\" in asset]" ] },
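{ "cell_type": "markdown", "metadata": {}, "source": [ "Each base name shows up several times with an `@...` suffix, which appears to denote environment-specific variants of the same asset card (a bare `@` marking the generic entry). To list just the unique base names, plain Python string handling (not a fairseq2 API) is enough:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Strip the \"@env\" qualifier to see the unique base asset names\n", "names = [n for n in asset_store.retrieve_names() if \"llama3_1\" in n]\n", "sorted({n.split(\"@\")[0] for n in names})" ] },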
{ "cell_type": "markdown", "metadata": {}, "source": [ "Loading a pretrained model can also be done directly from the hub." ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TransformerDecoderModel(\n", " model_dim=2048\n", " (decoder_frontend): TransformerEmbeddingFrontend(\n", " model_dim=2048\n", " (embed): StandardEmbedding(num_embeddings=128256, embedding_dim=2048, init_fn=init_embed)\n", " (pos_encoder): None\n", " (layer_norm): None\n", " (dropout): None\n", " )\n", " (decoder): StandardTransformerDecoder(\n", " model_dim=2048, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE\n", " (layers): ModuleList(\n", " (0-15): 16 x StandardTransformerDecoderLayer(\n", " model_dim=2048, norm_order=PRE\n", " (self_attn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)\n", " (self_attn): StandardMultiheadAttention(\n", " num_heads=32, model_dim=2048, num_key_value_heads=8\n", " (q_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)\n", " (k_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)\n", " (v_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)\n", " (pos_encoder): RotaryEncoder(encoding_dim=64, max_seq_len=131072)\n", " (sdpa): TorchSDPA(attn_dropout_p=0)\n", " (output_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)\n", " )\n", " (self_attn_norm): None\n", " (self_attn_dropout): None\n", " (self_attn_residual): StandardResidualConnect()\n", " (encoder_decoder_attn): None\n", " (encoder_decoder_attn_dropout): None\n", " (encoder_decoder_attn_residual): None\n", " (encoder_decoder_attn_layer_norm): None\n", " (ffn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)\n", " (ffn): GLUFeedForwardNetwork(\n", " model_dim=2048, inner_dim_scale=0.666667, inner_dim_to_multiple=256\n", " (gate_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)\n", " (gate_activation): SiLU()\n", " (inner_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)\n", " (inner_dropout): None\n", " (output_proj): Linear(input_dim=8192, output_dim=2048, bias=False, init_fn=init_projection)\n", " )\n", " (ffn_dropout): None\n", " (ffn_residual): StandardResidualConnect()\n", " )\n", " )\n", " (layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)\n", " (dropout): None\n", " )\n", " (final_proj): TiedProjection(input_dim=2048, output_dim=128256)\n", ")" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from fairseq2.models.llama import get_llama_model_hub\n", "\n", "model_hub = get_llama_model_hub()\n", "# Load a pre-trained model from the hub\n", "model = model_hub.load(\n", " \"llama3_2_1b\"\n", ") # here llama3_2_1b needs to be a registered asset card\n", "model" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Using Model Card\n", "\n", "We can also load a model directly from a model card." ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'base': 'llama3', 'model_arch': 'llama3_2_1b', '__base_path__': PosixPath('/fsx-checkpoints/yaoj/envs/fs2_nightly_pt25_cu121/conda/lib/python3.10/site-packages/fairseq2_ext/cards/models'), '__source__': 'package:fairseq2_ext.cards', 'checkpoint': '/fsx-ram/shared/Llama-3.2-1B/original/consolidated.00.pth', 'name': 'llama3_2_1b'}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_card = asset_store.retrieve_card(\"llama3_2_1b\")\n", "model_card" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TransformerDecoderModel(\n", " model_dim=2048\n", " (decoder_frontend): TransformerEmbeddingFrontend(\n", " model_dim=2048\n", " (embed): StandardEmbedding(num_embeddings=128256, embedding_dim=2048, init_fn=init_embed)\n", " (pos_encoder): None\n", " (layer_norm): None\n", " (dropout): None\n", " )\n", " (decoder): StandardTransformerDecoder(\n", " model_dim=2048, self_attn_mask_factory=CausalAttentionMaskFactory(), norm_order=PRE\n", " (layers): ModuleList(\n", " (0-15): 16 x StandardTransformerDecoderLayer(\n", " model_dim=2048, norm_order=PRE\n", " (self_attn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)\n", " (self_attn): StandardMultiheadAttention(\n", " num_heads=32, model_dim=2048, num_key_value_heads=8\n", " (q_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)\n", " (k_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)\n", " (v_proj): Linear(input_dim=2048, output_dim=512, bias=False, init_fn=init_projection)\n", " (pos_encoder): RotaryEncoder(encoding_dim=64, max_seq_len=131072)\n", " (sdpa): TorchSDPA(attn_dropout_p=0)\n", " (output_proj): Linear(input_dim=2048, output_dim=2048, bias=False, init_fn=init_projection)\n", " )\n", " (self_attn_norm): None\n", " (self_attn_dropout): None\n", " (self_attn_residual): StandardResidualConnect()\n", " (encoder_decoder_attn): None\n", " (encoder_decoder_attn_dropout): None\n", " (encoder_decoder_attn_residual): None\n", " (encoder_decoder_attn_layer_norm): None\n", " (ffn_layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)\n", " (ffn): GLUFeedForwardNetwork(\n", " model_dim=2048, inner_dim_scale=0.666667, inner_dim_to_multiple=256\n", " (gate_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)\n", " (gate_activation): SiLU()\n", " (inner_proj): Linear(input_dim=2048, output_dim=8192, bias=False, init_fn=init_projection)\n", " (inner_dropout): None\n", " (output_proj): Linear(input_dim=8192, output_dim=2048, bias=False, init_fn=init_projection)\n", " )\n", " (ffn_dropout): None\n", " (ffn_residual): StandardResidualConnect()\n", " )\n", " )\n", " (layer_norm): RMSNorm(normalized_shape=(2048,), eps=1E-05, elementwise_affine=True, impl=torch)\n", " (dropout): None\n", " )\n", " (final_proj): TiedProjection(input_dim=2048, output_dim=128256)\n", ")" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "llama_model = model_hub.load(model_card)\n", "llama_model" ] },
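{ "cell_type": "markdown", "metadata": {}, "source": [ "As with any PyTorch module, the freshly loaded model can be moved to the target device and switched to eval mode before running inference. A small sketch reusing the `device` and `dtype` chosen earlier:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Standard PyTorch: place the pretrained model on the target device/dtype and disable dropout\n", "llama_model = llama_model.to(device=device, dtype=dtype).eval()\n", "print(f\"Device: {next(llama_model.parameters()).device}, dtype: {next(llama_model.parameters()).dtype}\")" ] }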
], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }