fairseq2.models.hub

The model hub provides a unified interface for working with model families in fairseq2. Each model family has its own hub that exposes methods for loading models, creating new instances, listing architectures, and more.

Quick Start

from pathlib import Path

from fairseq2.models.qwen import get_qwen_model_hub

# Get the model hub for the Qwen family
hub = get_qwen_model_hub()

# List available architectures
archs = hub.get_archs()
print(f"Available architectures: {archs}")

# Create a new model with randomly initialized parameters
config = hub.get_arch_config("qwen25_7b")
model = hub.create_new_model(config)

# Load a pretrained model from its asset card
model = hub.load_model("qwen25_7b")

# Load a model from a custom checkpoint
model = hub.load_custom_model(Path("/path/to/checkpoint.pt"), config)

Core Classes

ModelHub

final class fairseq2.models.hub.ModelHub(family: ModelFamily, asset_store: AssetStore)

Bases: Generic[ModelT, ModelConfigT]

Provides a high-level interface for loading and creating models from a specific model family.

This class serves as the primary entry point for working with models of a particular family (e.g., LLaMA, Qwen, etc.). It handles model discovery, configuration loading, and model instantiation.

Key Methods:

iter_cards() -> Iterator[AssetCard]

Iterates over all asset cards belonging to this model family.

from fairseq2.models.qwen import get_qwen_model_hub

# List all available Qwen models.
for card in get_qwen_model_hub().iter_cards():
    print(f"Model: {card.name}")

get_archs() -> set[str]

Returns the set of supported model architectures in this family.

from fairseq2.models.qwen import get_qwen_model_hub

# List all available Qwen architectures.
for arch in get_qwen_model_hub().get_archs():
    print(f"Architecture: {arch}")

get_arch_config(arch: str) -> ModelConfigT

Returns the configuration for the specified model architecture.

from fairseq2.models.qwen import get_qwen_model_hub

config = get_qwen_model_hub().get_arch_config("qwen25_7b")

print(config)
Raises:

ModelArchitectureNotKnownError – If arch is not a known architecture in this family.

maybe_get_arch_config(arch: str) -> ModelConfigT | None

Returns the configuration for the specified model architecture, or None if not known.
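The pairing of get_arch_config() and maybe_get_arch_config() follows the familiar "raise versus return None" lookup convention. The following is a minimal, self-contained sketch of that convention. ToyConfig, ToyArchRegistry, and ToyArchNotKnownError are hypothetical stand-ins (the last one for ModelArchitectureNotKnownError), not fairseq2 code.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical stand-in for a family config such as QwenConfig.
    num_layers: int
    model_dim: int


class ToyArchNotKnownError(Exception):
    # Hypothetical stand-in for ModelArchitectureNotKnownError.
    def __init__(self, arch: str, family: str) -> None:
        super().__init__(f"{arch!r} is not a known architecture of the {family!r} family.")
        self.arch = arch
        self.family = family


class ToyArchRegistry:
    def __init__(self, family: str, archs: dict[str, ToyConfig]) -> None:
        self._family = family
        self._archs = archs

    def get_archs(self) -> set[str]:
        return set(self._archs)

    def maybe_get_arch_config(self, arch: str) -> ToyConfig | None:
        # Unknown architectures yield None instead of an exception.
        return self._archs.get(arch)

    def get_arch_config(self, arch: str) -> ToyConfig:
        # Delegate to the "maybe" variant and turn None into an error.
        config = self.maybe_get_arch_config(arch)
        if config is None:
            raise ToyArchNotKnownError(arch, self._family)
        return config


registry = ToyArchRegistry("toy", {"toy_small": ToyConfig(num_layers=2, model_dim=64)})
```

The maybe_ variant suits cases where an unknown name is expected (for example, validating user input), while the raising variant suits cases where an unknown name indicates a bug.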

get_model_config(card: AssetCard | str) -> ModelConfigT

Returns the model configuration from an asset card.

This method loads the base architecture configuration and applies any model-specific overrides specified in the asset card.

As a convenience, this method also accepts an asset name instead of an asset card.

from fairseq2.assets import get_asset_store
from fairseq2.models.qwen import QwenConfig, get_qwen_model_hub

card = get_asset_store().retrieve_card("qwen25_7b_instruct")

qwen_config = get_qwen_model_hub().get_model_config(card)

# As a convenience, the card can be omitted and the model name can
# be passed directly to `get_model_config()`:
qwen_config = get_qwen_model_hub().get_model_config("qwen25_7b_instruct")

print(qwen_config)
Raises:
  • ModelNotKnownError – If card is a string and no asset card with that name exists.

  • AssetCardError – If the asset card’s model family does not match this hub’s family.
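Conceptually, get_model_config() starts from the base configuration of the card's architecture and applies the card's overrides on top. The sketch below models that with dataclasses.replace(). The card is a plain dict, and its "model_config" override key, the ARCHS and CARDS tables, and toy_get_model_config() are all hypothetical. fairseq2's real card schema and override mechanism may differ.

```python
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class ToyConfig:
    # Hypothetical stand-in for a family config class.
    num_layers: int = 32
    model_dim: int = 4096
    dropout_p: float = 0.1


# Hypothetical base architecture configurations, keyed by architecture name.
ARCHS: dict[str, ToyConfig] = {
    "toy_7b": ToyConfig(),
    "toy_1b": ToyConfig(num_layers=16, model_dim=2048),
}

# Hypothetical asset cards keyed by name; "model_config" holds overrides.
CARDS: dict[str, dict[str, Any]] = {
    "toy_7b_instruct": {"model_arch": "toy_7b", "model_config": {"dropout_p": 0.0}},
}


def toy_get_model_config(card: dict[str, Any] | str) -> ToyConfig:
    # As a convenience, accept a card name and resolve it to a card.
    if isinstance(card, str):
        try:
            card = CARDS[card]
        except KeyError:
            raise LookupError(f"no asset card named {card!r}") from None

    base = ARCHS[card["model_arch"]]

    # Apply the card-level overrides on top of the architecture defaults.
    return replace(base, **card.get("model_config", {}))
```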

create_new_model(config: ModelConfigT, *, gangs: Gangs | None = None, dtype: dtype | None = None, meta: bool = False) -> ModelT
create_new_model(config: ModelConfigT, *, device: device | None = None, dtype: dtype | None = None, meta: bool = False) -> ModelT

Creates a new model instance with the specified configuration.

This method creates a fresh model without loading any pretrained weights. The model will be initialized with random parameters according to the architecture’s default initialization scheme.

If gangs is provided, it is used to apply parallelism (i.e. model parallelism) to the model. Any parallelism strategy that the model family does not support is ignored. For instance, if gangs.tp.size > 1 but the model does not support tensor parallelism, the model is instantiated with regular attention and feed-forward network blocks. If gangs is None, the whole model is initialized without any parallelism.

If device is provided, the model is created on that device; otherwise, the device returned by torch.get_default_device() is used. If gangs is provided, gangs.root.device is used instead. Note that device and gangs cannot be provided together.

If dtype is provided, it is used as the default data type of the model parameters and buffers; otherwise, the data type returned by torch.get_default_dtype() is used.

If meta is True, the model is created on the meta device for memory-efficient initialization. This requires the model family to support the meta device.

from fairseq2.models.qwen import QwenConfig, get_qwen_model_hub

# Use the default Qwen configuration except the number of
# decoder layers.
config = QwenConfig(num_layers=16)

qwen_model = get_qwen_model_hub().create_new_model(config)
Raises:
  • ValueError – If both gangs and device are provided.

  • NotSupportedError – If meta is True but the model family doesn’t support meta device.
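The rules above for gangs, device, and dtype amount to a small argument-resolution step. This hedged sketch encodes them with plain strings instead of torch.device and torch.dtype, and with a hypothetical ToyGangs object that exposes only root_device and tp_size, so it runs without PyTorch. DEFAULT_DEVICE and DEFAULT_DTYPE stand in for torch.get_default_device() and torch.get_default_dtype(). The actual fairseq2 logic is not reproduced here.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyGangs:
    # Hypothetical stand-in for Gangs: only the fields this sketch needs.
    root_device: str
    tp_size: int = 1


DEFAULT_DEVICE = "cpu"     # Stands in for torch.get_default_device().
DEFAULT_DTYPE = "float32"  # Stands in for torch.get_default_dtype().


def resolve_placement(
    *,
    gangs: ToyGangs | None = None,
    device: str | None = None,
    dtype: str | None = None,
    supports_tp: bool = False,
) -> tuple[str, str, int]:
    """Return (device, dtype, effective tensor-parallel size)."""
    if gangs is not None and device is not None:
        raise ValueError("`gangs` and `device` cannot be provided together.")

    if gangs is not None:
        # The root gang's device takes the place of an explicit device.
        resolved_device = gangs.root_device
        # Unsupported parallelism is ignored rather than treated as an error.
        tp_size = gangs.tp_size if supports_tp else 1
    else:
        resolved_device = device if device is not None else DEFAULT_DEVICE
        tp_size = 1

    resolved_dtype = dtype if dtype is not None else DEFAULT_DTYPE

    return resolved_device, resolved_dtype, tp_size
```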

load_model(card: AssetCard | str, *, gangs: Gangs | None = None, dtype: dtype | None = None, config: ModelConfigT | None = None, mmap: bool = False, progress: bool = True) -> ModelT
load_model(card: AssetCard | str, *, device: device | None = None, dtype: dtype | None = None, config: ModelConfigT | None = None, mmap: bool = False, progress: bool = True) -> ModelT

Loads a pretrained model from an asset card.

This method downloads the model checkpoint (if necessary) and loads the pretrained weights into a model instance. The model architecture and configuration are determined from the asset card metadata.

As a convenience, this method also accepts an asset name instead of an asset card.

If gangs is provided, it is used to apply parallelism (i.e. model parallelism) to the model. Any parallelism strategy that the model family does not support is ignored. For instance, if gangs.tp.size > 1 but the model does not support tensor parallelism, the model is instantiated with regular attention and feed-forward network blocks. If gangs is None, the whole model is initialized without any parallelism.

If device is provided, the model is created on that device; otherwise, the device returned by torch.get_default_device() is used. If gangs is provided, gangs.root.device is used instead. Note that device and gangs cannot be provided together.

If dtype is provided, it is used as the default data type of the model parameters and buffers; otherwise, the data type returned by torch.get_default_dtype() is used.

If config is provided, it overrides the default model configuration from the asset card; if None, the configuration specified in the card is used. It is typically used to make slight adjustments, such as tuning dropout probabilities, without changing the architecture.

If mmap is True, the model checkpoint is memory-mapped. This can reduce memory usage but may slow down loading on some systems.

If progress is True, a progress bar is displayed while the model is downloaded and loaded.

from fairseq2.assets import get_asset_store
from fairseq2.models.qwen import QwenConfig, get_qwen_model_hub

card = get_asset_store().retrieve_card("qwen25_7b_instruct")

qwen_model = get_qwen_model_hub().load_model(card)

# As a convenience, the card can be omitted and the model name can
# be passed directly to `load_model()`:
qwen_model = get_qwen_model_hub().load_model("qwen25_7b_instruct")
Raises:
  • ModelNotKnownError – If card is a string and no asset card with that name exists.

  • AssetCardError – If the asset card’s model family doesn’t match this hub’s family.

  • ValueError – If both gangs and device are provided.
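A typical use of the config argument is to take the card's configuration and change only training-time settings such as dropout. The sketch below adds a guard that rejects overrides touching architecture-defining fields. ToyConfig, ARCH_FIELDS, and override_config() are hypothetical, and this check is not taken from fairseq2.

```python
from __future__ import annotations

from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class ToyConfig:
    # Hypothetical stand-in for a family config class.
    num_layers: int = 32
    model_dim: int = 4096
    dropout_p: float = 0.1


# Hypothetical: fields whose change would alter parameter shapes.
ARCH_FIELDS = {"num_layers", "model_dim"}


def override_config(base: ToyConfig, **changes: object) -> ToyConfig:
    # Catch typos in field names early.
    unknown = set(changes) - {f.name for f in fields(base)}
    if unknown:
        raise TypeError(f"unknown config fields: {sorted(unknown)}")

    # Allow architecture fields only if their value stays the same.
    touched = {
        name for name in changes
        if name in ARCH_FIELDS and changes[name] != getattr(base, name)
    }
    if touched:
        raise ValueError(f"override changes the architecture: {sorted(touched)}")

    return replace(base, **changes)
```

With a real hub, the base would come from get_model_config() and the result would be passed as load_model(..., config=...).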

load_custom_model(path: Path, config: ModelConfigT, *, gangs: Gangs | None = None, dtype: dtype | None = None, mmap: bool = False, restrict: bool | None = None, progress: bool = True) -> ModelT
load_custom_model(path: Path, config: ModelConfigT, *, device: device | None = None, dtype: dtype | None = None, mmap: bool = False, restrict: bool | None = None, progress: bool = True) -> ModelT

Loads a model from a custom checkpoint file.

This method is useful for loading models from custom training runs or third-party checkpoints that are not available through the asset store.

config specifies the model configuration. It must match the architecture of the saved checkpoint.

If gangs is provided, it is used to apply parallelism (i.e. model parallelism) to the model. Any parallelism strategy that the model family does not support is ignored. For instance, if gangs.tp.size > 1 but the model does not support tensor parallelism, the model is instantiated with regular attention and feed-forward network blocks. If gangs is None, the whole model is initialized without any parallelism.

If device is provided, the model is created on that device; otherwise, the device returned by torch.get_default_device() is used. If gangs is provided, gangs.root.device is used instead. Note that device and gangs cannot be provided together.

If dtype is provided, it is used as the default data type of the model parameters and buffers; otherwise, the data type returned by torch.get_default_dtype() is used.

If mmap is True, the model checkpoint is memory-mapped. This can reduce memory usage but may slow down loading on some systems.

If restrict is True, pickle (if used) is restricted to loading only tensors and types that can be safely serialized and deserialized. If None, the family's default restriction setting is used.

If progress is True, displays a progress bar during model download and loading.
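The restrict option applies the same idea as a restricted unpickler: only an allow-list of types may be reconstructed from the pickle stream, so a malicious checkpoint cannot execute arbitrary code on load. The stdlib-only sketch below shows the mechanism with pickle.Unpickler.find_class(), following the "Restricting Globals" approach from the Python pickle documentation. The allow-list is hypothetical. This is an analogy (PyTorch's torch.load(weights_only=True) uses a similar approach), not fairseq2's implementation.

```python
import collections
import datetime
import io
import pickle

# Hypothetical allow-list: (module, name) -> object that may be reconstructed.
SAFE_GLOBALS = {("collections", "OrderedDict"): collections.OrderedDict}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Called for every global referenced by the pickle stream; refuse
        # anything not explicitly allowed instead of importing it.
        try:
            return SAFE_GLOBALS[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden") from None


def restricted_loads(data: bytes) -> object:
    return RestrictedUnpickler(io.BytesIO(data)).load()


# Plain containers and an allowed OrderedDict load normally.
state = restricted_loads(pickle.dumps(collections.OrderedDict(w=[0.5, -0.5])))

# Any other global, here datetime.date, is rejected.
try:
    restricted_loads(pickle.dumps(datetime.date(2024, 1, 1)))
    rejected = False
except pickle.UnpicklingError:
    rejected = True
```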

from fairseq2.models.qwen import QwenConfig, get_qwen_model_hub

checkpoint_path = ...

# The checkpoint contains a Qwen model with 16 decoder layers.
config = QwenConfig(num_layers=16)

qwen_model = get_qwen_model_hub().load_custom_model(checkpoint_path, config)
Raises:

iter_checkpoint(path: Path, config: ModelConfigT, *, gangs: Gangs | None = None, mmap: bool = False, restrict: bool | None = None) -> Iterator[tuple[str, Tensor]]

Lazily loads parameters from the specified model checkpoint path.

Yields tensors one at a time to minimize memory usage if the underlying checkpoint format allows it.

This method provides low-level access to checkpoint contents without loading the full model into memory. It’s useful for checkpoint inspection, custom loading logic, or memory-efficient parameter processing.

config specifies the model configuration used to determine the expected parameter structure in the checkpoint.

If gangs is provided, it is used to determine the distributed target configuration and to shard yielded parameters accordingly. If None, no sharding will be performed and full parameters will be yielded.

If mmap is True, the checkpoint will be memory-mapped. This can reduce memory usage but may cause slower load times on some systems.

If restrict is True, pickle (if used) will be restricted to load only tensors and types that can be safely serialized and deserialized. If None, the default restriction setting of the family will be used.

Yields pairs of (parameter name, parameter) for each parameter in the checkpoint.
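Loading a full checkpoint versus calling iter_checkpoint() is essentially the difference between a list and a generator. This stdlib-only sketch uses a hypothetical in-memory checkpoint of nested lists in place of tensors, and hypothetical tp_rank/tp_size arguments in place of gangs. It shows parameters being produced one at a time and, optionally, sharded by rows. Which dimension fairseq2 shards for a given parameter depends on the model family and is not modeled here.

```python
from __future__ import annotations

from collections.abc import Iterator

Matrix = list[list[float]]


def iter_params(
    checkpoint: dict[str, Matrix], *, tp_rank: int = 0, tp_size: int = 1
) -> Iterator[tuple[str, Matrix]]:
    """Yield (name, parameter) pairs one at a time, optionally sharded by rows."""
    for name, param in checkpoint.items():
        if tp_size == 1:
            # No sharding: yield the full parameter.
            yield name, param
            continue

        if len(param) % tp_size != 0:
            raise ValueError(f"{name}: {len(param)} rows not divisible by {tp_size}")

        # Each rank keeps a contiguous block of rows.
        rows_per_rank = len(param) // tp_size
        start = tp_rank * rows_per_rank
        yield name, param[start : start + rows_per_rank]
```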

Raises:

ModelHubAccessor

final class fairseq2.models.hub.ModelHubAccessor(family_name: str, kls: type[ModelT], config_kls: type[ModelConfigT])

Bases: Generic[ModelT, ModelConfigT]

Creates a ModelHub instance when called.

This class provides a strongly-typed way to access model hubs. Its direct use is meant for model authors rather than library users.

See src/fairseq2/models/llama/hub.py as an example.

The use of ModelHubAccessor for model authors
from fairseq2.models import ModelHubAccessor

# Defined in the Python module where the model is implemented. MyModel and
# MyModelConfig are the model author's own model and configuration classes.
get_my_model_hub = ModelHubAccessor(
    family_name="my_model_family", kls=MyModel, config_kls=MyModelConfig
)

# `get_my_model_hub()` is treated as a standalone function by the model
# users in other parts of the code like below:
model_config = MyModelConfig()

model = get_my_model_hub().create_new_model(model_config)

Provides access to model hubs for specific families. Can be used by model implementors to create hub accessors for their model families, like fairseq2.models.qwen.hub.get_qwen_model_hub().
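The accessor idea, a callable object that remembers a family name together with the model and config types and returns a typed hub when called, can be sketched with typing.Generic. Every name below (ToyHub, ToyHubAccessor, REGISTERED_FAMILIES, MyModel, MyModelConfig) is hypothetical and only mirrors the shape of the API described above; src/fairseq2/models/llama/hub.py shows the real usage.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Generic, TypeVar

ModelT = TypeVar("ModelT")
ConfigT = TypeVar("ConfigT")

# Hypothetical set of registered family names (fairseq2 keeps its own registry).
REGISTERED_FAMILIES: set[str] = set()


class ToyHub(Generic[ModelT, ConfigT]):
    def __init__(self, family_name: str, kls: type[ModelT], config_kls: type[ConfigT]) -> None:
        self.family_name = family_name
        self._kls = kls
        self._config_kls = config_kls

    def create_new_model(self, config: ConfigT) -> ModelT:
        # Guard against passing another family's config at runtime.
        if not isinstance(config, self._config_kls):
            raise TypeError(f"expected {self._config_kls.__name__}, got {type(config).__name__}")
        return self._kls(config)


class ToyHubAccessor(Generic[ModelT, ConfigT]):
    def __init__(self, family_name: str, kls: type[ModelT], config_kls: type[ConfigT]) -> None:
        self._family_name = family_name
        self._kls = kls
        self._config_kls = config_kls

    def __call__(self) -> ToyHub[ModelT, ConfigT]:
        # The family is looked up at call time, not at import time.
        if self._family_name not in REGISTERED_FAMILIES:
            raise LookupError(f"model family {self._family_name!r} is not registered")
        return ToyHub(self._family_name, self._kls, self._config_kls)


# --- What a model author might write (all names hypothetical) ---
@dataclass
class MyModelConfig:
    num_layers: int = 2


class MyModel:
    def __init__(self, config: MyModelConfig) -> None:
        self.config = config


REGISTERED_FAMILIES.add("my_model_family")
get_my_model_hub = ToyHubAccessor("my_model_family", MyModel, MyModelConfig)

# --- What a model user writes ---
model = get_my_model_hub().create_new_model(MyModelConfig(num_layers=4))
```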

Global Functions

load_model

fairseq2.models.hub.load_model(card: AssetCard | str, *, gangs: Gangs | None = None, dtype: dtype | None = None, config: object = None, mmap: bool = False, progress: bool = True) -> Module
fairseq2.models.hub.load_model(card: AssetCard | str, *, device: device | None = None, dtype: dtype | None = None, config: object = None, mmap: bool = False, progress: bool = True) -> Module

Loads a pretrained model from an asset card.

This function downloads the model checkpoint (if necessary) and loads the pretrained weights into a model instance. The model architecture and configuration are determined from the asset card metadata.

As a convenience, this function also accepts an asset name instead of an asset card.

The difference between load_model and ModelHub.load_model() is as follows:

  • load_model provides a unified interface for loading models across all model families. It determines the appropriate model family based on asset card metadata and delegates to the family-specific loading logic.

  • The tradeoff is that (1) the config parameter of load_model is not type-safe, and (2) it is possible to accidentally load an unintended model, since the function is not constrained to a specific family.

  • The general recommendation is to use ModelHub.load_model() if the model family is known in advance, and to use load_model if the decision about the model and its family needs to be made at runtime.
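One way to picture the family-agnostic load_model() is as a dispatcher: read the family from the card, find that family's loader, and delegate. In the sketch below, the cards, their "model_family" key, the LOADERS table, and toy_load_model() are hypothetical, and loaders return strings instead of models. At this level, config can only be typed as object, because each family expects its own config class. This is the type-safety tradeoff described in the list above.

```python
from __future__ import annotations

from typing import Any, Callable

# Hypothetical asset cards keyed by name.
CARDS: dict[str, dict[str, str]] = {
    "toy_llm_7b": {"model_family": "toy_llm", "model_arch": "7b"},
    "toy_asr_base": {"model_family": "toy_asr", "model_arch": "base"},
}

# Hypothetical per-family loaders; each would expect its own config type.
Loader = Callable[[dict[str, str], Any], str]
LOADERS: dict[str, Loader] = {
    "toy_llm": lambda card, config: f"ToyLLM({card['model_arch']}, config={config})",
    "toy_asr": lambda card, config: f"ToyASR({card['model_arch']}, config={config})",
}


def toy_load_model(card: dict[str, str] | str, *, config: object = None) -> str:
    if isinstance(card, str):
        if card not in CARDS:
            # Plays the role of ModelNotKnownError.
            raise LookupError(f"no model named {card!r}")
        card = CARDS[card]

    family = card["model_family"]
    loader = LOADERS.get(family)
    if loader is None:
        # Plays the role of ModelFamilyNotKnownError.
        raise LookupError(f"model family {family!r} is not registered")

    # `config` is passed through untyped; only the family loader knows its type.
    return loader(card, config)
```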

If gangs is provided, it is used to apply parallelism (i.e. model parallelism) to the model. Any parallelism strategy that the model family does not support is ignored. For instance, if gangs.tp.size > 1 but the model does not support tensor parallelism, the model is instantiated with regular attention and feed-forward network blocks. If gangs is None, the whole model is initialized without any parallelism.

If device is provided, the model is created on that device; otherwise, the device returned by torch.get_default_device() is used. If gangs is provided, gangs.root.device is used instead. Note that device and gangs cannot be provided together.

If dtype is provided, it is used as the default data type of the model parameters and buffers; otherwise, the data type returned by torch.get_default_dtype() is used.

If config is provided, it overrides the default model configuration from the asset card; if None, the configuration specified in the card is used. It is typically used to make slight adjustments, such as tuning dropout probabilities, without changing the architecture.

If mmap is True, the model checkpoint is memory-mapped. This can reduce memory usage but may slow down loading on some systems.

If progress is True, a progress bar is displayed while the model is downloaded and loaded.

from fairseq2.assets import get_asset_store
from fairseq2.models.hub import load_model

card = get_asset_store().retrieve_card("qwen25_7b_instruct")

qwen_model = load_model(card)

# As a convenience, the card can be omitted and the model name can
# be passed directly to `load_model()`:
wav2vec2_model = load_model("wav2vec2_asr_base_10h")
Raises:
  • ModelNotKnownError – If card is a string and no asset card with that name exists.

  • AssetCardError – If the asset card’s model family doesn’t match this hub’s family.

  • ValueError – If both gangs and device are provided.

The main function for loading models across all families. Automatically determines the appropriate model family from the asset card.

from fairseq2.models.hub import load_model

# Load any model by name
model = load_model("qwen25_7b")
model = load_model("llama3_8b")
model = load_model("mistral_7b")

Working with Model Families

Each model family provides its own hub accessor function:

Qwen Models

from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()

# Available architectures
archs = hub.get_archs()  # {'qwen25_0.5b', 'qwen25_1.5b', 'qwen25_3b', ...}

# Get configuration for specific architecture
config = hub.get_arch_config("qwen25_7b")

# Create new model
model = hub.create_new_model(config)

LLaMA Models

from fairseq2.models.llama import get_llama_model_hub

hub = get_llama_model_hub()

# List available LLaMA architectures
archs = hub.get_archs()

# Load specific LLaMA model
model = hub.load_model("llama3_8b")

Mistral Models

from fairseq2.models.mistral import get_mistral_model_hub

hub = get_mistral_model_hub()
model = hub.load_model("mistral_7b")

Advanced Usage

Custom Model Loading

Load models from custom checkpoints with specific configurations:

from pathlib import Path
from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()

# Get base configuration
config = hub.get_arch_config("qwen25_7b")

# Modify configuration if needed
config.max_seq_len = 32768

# Load from custom checkpoint
checkpoint_path = Path("/path/to/my/checkpoint.pt")
model = hub.load_custom_model(checkpoint_path, config)
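load_custom_model() requires the config to match the architecture of the saved checkpoint. Before loading, it can therefore help to compare the parameter shapes a config implies with the shapes stored in the checkpoint. The sketch below does this for a hypothetical ToyDecoderConfig and a made-up parameter naming scheme. Real fairseq2 parameter names and shapes differ per family.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyDecoderConfig:
    # Hypothetical architecture fields that determine parameter shapes.
    num_layers: int
    model_dim: int
    vocab_size: int


def expected_shapes(config: ToyDecoderConfig) -> dict[str, tuple[int, ...]]:
    # Hypothetical naming scheme: one embedding plus one projection per layer.
    shapes = {"embed.weight": (config.vocab_size, config.model_dim)}
    for i in range(config.num_layers):
        shapes[f"layers.{i}.proj.weight"] = (config.model_dim, config.model_dim)
    return shapes


def find_mismatches(
    config: ToyDecoderConfig, checkpoint_shapes: dict[str, tuple[int, ...]]
) -> list[str]:
    """Return a sorted list of problems; an empty list means the config fits."""
    expected = expected_shapes(config)
    problems = [f"missing: {name}" for name in expected.keys() - checkpoint_shapes.keys()]
    problems += [f"unexpected: {name}" for name in checkpoint_shapes.keys() - expected.keys()]
    for name in expected.keys() & checkpoint_shapes.keys():
        if expected[name] != checkpoint_shapes[name]:
            problems.append(f"shape mismatch: {name} {checkpoint_shapes[name]} != {expected[name]}")
    return sorted(problems)
```

With a real hub, the checkpoint side could be built as {name: tuple(t.shape) for name, t in hub.iter_checkpoint(checkpoint_path, config)}, while the expected side would have to follow the family's actual parameter layout.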

Iterating Over Model Cards

Discover all available models in a family:

from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()

# List all Qwen model cards
for card in hub.iter_cards():
    print(f"Model: {card.name}")
    print(f"  Architecture: {card.field('model_arch').as_(str)}")
    print(f"  Checkpoint: {card.field('checkpoint').as_(str)}")

Checkpoint Inspection

Iterate over checkpoint tensors without loading the full model:

from pathlib import Path
from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()
config = hub.get_arch_config("qwen25_7b")

checkpoint_path = Path("/path/to/checkpoint.pt")

# Inspect checkpoint contents
for name, tensor in hub.iter_checkpoint(checkpoint_path, config):
    print(f"Parameter: {name}, Shape: {tensor.shape}")
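Because iter_checkpoint() yields (name, tensor) pairs, summary statistics can be gathered in a single streaming pass. The sketch below tallies parameter counts per top-level module. To stay runnable without PyTorch, it consumes (name, shape) pairs, which the loop above could produce with tuple(tensor.shape). Grouping by the first dotted component of each name is an assumption about the naming scheme. The helper name count_params_by_module() is hypothetical.

```python
from __future__ import annotations

import math
from collections import Counter
from collections.abc import Iterable


def count_params_by_module(named_shapes: Iterable[tuple[str, tuple[int, ...]]]) -> Counter[str]:
    counts: Counter[str] = Counter()
    for name, shape in named_shapes:
        # Group by the first dotted component, e.g. "decoder" in "decoder.layers.0.w".
        module = name.split(".", 1)[0]
        # math.prod(()) == 1, so scalar (0-dim) parameters count as one element.
        counts[module] += math.prod(shape)
    return counts
```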

Error Handling

Common Exceptions

exception fairseq2.models.hub.ModelNotKnownError(name: str)

Bases: Exception

Raised when a requested model name is not found in the asset store.

exception fairseq2.models.hub.ModelFamilyNotKnownError(name: str)

Bases: Exception

Raised when a requested model family is not registered.

exception fairseq2.models.hub.ModelArchitectureNotKnownError(arch: str, family: str | None = None)

Bases: Exception

Raised when a requested model architecture is not supported by a model family.

family defaults to None for backward compatibility. New code must specify a model family when raising this error.

Example Error Handling

from fairseq2.models.hub import (
    ModelArchitectureNotKnownError,
    ModelNotKnownError,
    load_model,
)
from fairseq2.models.qwen import get_qwen_model_hub

try:
    model = load_model("nonexistent_model")
except ModelNotKnownError as e:
    print(f"Model not found: {e.name}")

# Create the hub outside the `try` block so it is always defined in the handler.
hub = get_qwen_model_hub()

try:
    config = hub.get_arch_config("invalid_arch")
except ModelArchitectureNotKnownError as e:
    print(f"Architecture not found: {e.arch}")
    print(f"Available architectures: {hub.get_archs()}")
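When an architecture name is misspelled, printing every known architecture can be noisy. A small refinement is to suggest the closest names with difflib.get_close_matches() from the Python standard library. The helper suggest_archs() is hypothetical. It takes the known names as a plain set so it runs on its own; with a real hub, you would pass hub.get_archs().

```python
from __future__ import annotations

import difflib
from collections.abc import Iterable


def suggest_archs(requested: str, known: Iterable[str], limit: int = 3) -> list[str]:
    # Rank known architecture names by similarity to the requested one;
    # names scoring below the 0.6 cutoff are dropped.
    return difflib.get_close_matches(requested, sorted(known), n=limit, cutoff=0.6)
```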

See Also