# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import Generic, TypeVar, cast, final, overload
import torch
from torch import Tensor
from torch.nn import Module
from fairseq2.assets import AssetCard, AssetCardError, AssetNotFoundError, AssetStore
from fairseq2.data_type import DataType
from fairseq2.device import CPU, Device, get_current_device
from fairseq2.error import InternalError
from fairseq2.gang import Gangs, create_fake_gangs
from fairseq2.models.family import ModelFamily
from fairseq2.runtime.dependency import get_dependency_resolver
from fairseq2.runtime.lookup import Lookup
from fairseq2.utils.warn import _warn_deprecated
ModelT = TypeVar("ModelT", bound=Module)
ModelConfigT = TypeVar("ModelConfigT")
[docs]
@final
class ModelHub(Generic[ModelT, ModelConfigT]):
"""
Provides a high-level interface for loading and creating models from a
specific model family.
This class serves as the primary entry point for working with models of a
particular family (e.g., LLaMA, Qwen, etc.). It handles model discovery,
configuration loading, and model instantiation.
"""
def __init__(self, family: ModelFamily, asset_store: AssetStore) -> None:
self._family = family
self._asset_store = asset_store
[docs]
def iter_cards(self) -> Iterator[AssetCard]:
"""
Iterates over all asset cards belonging to this model family.
.. code:: python
from fairseq2.models.qwen import get_qwen_model_hub
# List all available Qwen models.
for card in get_qwen_model_hub().iter_cards():
print(f"Model: {card.name}")
"""
return self._asset_store.find_cards("model_family", self._family.name)
[docs]
def get_archs(self) -> set[str]:
"""
Returns the set of supported model architectures in this family.
.. code:: python
from fairseq2.models.qwen import get_qwen_model_hub
# List all available Qwen architectures.
for arch in get_qwen_model_hub().get_archs():
print(f"Architecture: {arch}")
"""
return self._family.get_archs()
[docs]
def get_arch_config(self, arch: str) -> ModelConfigT:
"""
Returns the configuration for the specified model architecture.
.. code:: python
from fairseq2.models.qwen import get_qwen_model_hub
config = get_qwen_model_hub().get_arch_config("qwen25_7b")
print(config)
:raises ModelArchitectureNotKnownError: If ``arch`` is not a known
architecture in this family.
"""
config = self.maybe_get_arch_config(arch)
if config is None:
raise ModelArchitectureNotKnownError(arch, self._family.name)
return config
[docs]
def maybe_get_arch_config(self, arch: str) -> ModelConfigT | None:
"""
Returns the configuration for the specified model architecture, or
``None`` if not known.
"""
config = self._family.maybe_get_arch_config(arch)
return cast(ModelConfigT | None, config)
[docs]
def get_model_config(self, card: AssetCard | str) -> ModelConfigT:
"""
Returns the model configuration from an asset card.
This method loads the base architecture configuration and applies any
model-specific overrides specified in the asset card.
As a convenience, this method also accepts an asset name instead of an
asset card.
.. code:: python
from fairseq2.assets import get_asset_store
from fairseq2.models.qwen import QwenConfig, get_qwen_model_hub
card = get_asset_store().retrieve_card("qwen25_7b_instruct")
qwen_config = get_qwen_model_hub().get_model_config(card)
# As a convenience, the card can be omitted and the model name can
# be passed directly to `get_model_config()`:
qwen_config = get_qwen_model_hub().get_model_config("qwen25_7b_instruct")
print(qwen_config)
:raises ModelNotKnownError: If ``card`` is a string and no asset card
with that name exists.
:raises AssetCardError: If the asset card's model family does not match
this hub's family.
"""
if isinstance(card, str):
name = card
try:
card = self._asset_store.retrieve_card(name)
except AssetNotFoundError:
raise ModelNotKnownError(name) from None
else:
name = card.name
family_name = card.field("model_family").as_(str)
if family_name != self._family.name:
msg = f"family field of the {name} asset card is expected to be {self._family.name}, but is {family_name} instead."
raise AssetCardError(name, msg)
config = self._family.get_model_config(card)
return cast(ModelConfigT, config)
@overload
def create_new_model(
self,
config: ModelConfigT,
*,
gangs: Gangs | None = None,
dtype: DataType | None = None,
meta: bool = False,
) -> ModelT: ...
@overload
def create_new_model(
self,
config: ModelConfigT,
*,
device: Device | None = None,
dtype: DataType | None = None,
meta: bool = False,
) -> ModelT: ...
[docs]
def create_new_model(
self,
config: ModelConfigT,
*,
gangs: Gangs | None = None,
device: Device | None = None,
dtype: DataType | None = None,
meta: bool = False,
) -> ModelT:
"""
Creates a new model instance with the specified configuration.
This method creates a fresh model without loading any pretrained weights.
The model will be initialized with random parameters according to the
architecture's default initialization scheme.
If ``gangs`` is provided, it will be used to apply parallelism (i.e.
model parallelism) to the initialized model. If the model family does
not support a certain parallelism strategy, that strategy will be
ignored. For instance if ``gangs.tp.size > 1``, but the model does not
support tensor parallelism, the model will be instantiated with regular
attention and feed-forward network blocks. If ``None``, the whole model
will be initialized without any parallelism.
If ``device`` is provided, the model will be created on the specified
device; otherwise, the device returned from :func:`torch.get_default_device`
will be used. Note that ``device`` and ``gangs`` cannot be provided
together. If ``gangs`` is provided, ``gangs.root.device`` will be used.
If ``dtype`` is provided, it will be used as the default data type of
the model parameters and buffers; otherwise, the data type returned from
:func:`torch.get_default_dtype` will be used.
If ``meta`` is ``True``, the model will be created on the meta device
for memory-efficient initialization. Only supported if the model family
supports meta device.
.. code:: python
from fairseq2.models.qwen import QwenConfig, get_qwen_model_hub
# Use the default Qwen configuration except the number of
# decoder layers.
config = QwenConfig(num_layers=16)
qwen_model = get_qwen_model_hub().create_new_model(config)
:raises ValueError: If both ``gangs`` and ``device`` are provided.
:raises NotSupportedError: If ``meta`` is ``True`` but the model family
doesn't support meta device.
"""
gangs = _get_effective_gangs(gangs, device)
if dtype is None:
dtype = torch.get_default_dtype()
model = self._family.create_new_model(config, gangs, dtype, meta)
return cast(ModelT, model)
@overload
def load_model(
self,
card: AssetCard | str,
*,
gangs: Gangs | None = None,
dtype: DataType | None = None,
config: ModelConfigT | None = None,
mmap: bool = False,
progress: bool = True,
) -> ModelT: ...
@overload
def load_model(
self,
card: AssetCard | str,
*,
device: Device | None = None,
dtype: DataType | None = None,
config: ModelConfigT | None = None,
mmap: bool = False,
progress: bool = True,
) -> ModelT: ...
[docs]
def load_model(
self,
card: AssetCard | str,
*,
gangs: Gangs | None = None,
device: Device | None = None,
dtype: DataType | None = None,
config: ModelConfigT | None = None,
mmap: bool = False,
progress: bool = True,
) -> ModelT:
"""
Loads a pretrained model from an asset card.
This method downloads the model checkpoint (if necessary) and loads the
pretrained weights into a model instance. The model architecture and
configuration are determined from the asset card metadata.
As a convenience, this method also accepts an asset name instead of an
asset card.
If ``gangs`` is provided, it will be used to apply parallelism (i.e.
model parallelism) to the initialized model. If the model family does
not support a certain parallelism strategy, that strategy will be
ignored. For instance if ``gangs.tp.size > 1``, but the model does not
support tensor parallelism, the model will be instantiated with regular
attention and feed-forward network blocks. If ``None``, the whole model
will be initialized without any parallelism.
If ``device`` is provided, the model will be created on the specified
device; otherwise, the device returned from :func:`torch.get_default_device`
will be used. Note that ``device`` and ``gangs`` cannot be provided
together. If ``gangs`` is provided, ``gangs.root.device`` will be used.
If ``dtype`` is provided, it will be used as the default data type of
the model parameters and buffers; otherwise, the data type returned from
:func:`torch.get_default_dtype` will be used.
If ``config`` is provided, it overrides the default model configuration
from the asset card. If ``None``, uses the configuration specified in
the card. Typically used to perform slight adjustments to the model
configuration such as tuning dropout probabilities without changing the
architecture.
If ``mmap`` is ``True``, the model checkpoint will be memory-mapped. This
can reduce memory usage but may cause slower load times on some systems.
If ``progress`` is ``True``, displays a progress bar during model
download and loading.
.. code:: python
from fairseq2.assets import get_asset_store
from fairseq2.models.qwen import QwenConfig, get_qwen_model_hub
card = get_asset_store().retrieve_card("qwen25_7b_instruct")
qwen_model = get_qwen_model_hub().load_model(card)
# As a convenience, the card can be omitted and the model name can
# be passed directly to `load_model()`:
qwen_model = get_qwen_model_hub().load_model("qwen25_7b_instruct")
:raises ModelNotKnownError: If ``card`` is a string and no asset card
with that name exists.
:raises AssetCardError: If the asset card's model family doesn't match
this hub's family.
:raises ValueError: If both ``gangs`` and ``device`` are provided.
"""
gangs = _get_effective_gangs(gangs, device)
if isinstance(card, str):
name = card
try:
card = self._asset_store.retrieve_card(name)
except AssetNotFoundError:
raise ModelNotKnownError(name) from None
else:
name = card.name
family_name = card.field("model_family").as_(str)
if family_name != self._family.name:
msg = f"family field of the {name} asset card is expected to be {self._family.name}, but is {family_name} instead."
raise AssetCardError(name, msg)
if dtype is None:
dtype = torch.get_default_dtype()
model = self._family.load_model(card, gangs, dtype, config, mmap, progress)
return cast(ModelT, model)
@overload
def load_custom_model(
self,
path: Path,
config: ModelConfigT,
*,
gangs: Gangs | None = None,
dtype: DataType | None = None,
mmap: bool = False,
restrict: bool | None = None,
progress: bool = True,
) -> ModelT: ...
@overload
def load_custom_model(
self,
path: Path,
config: ModelConfigT,
*,
device: Device | None = None,
dtype: DataType | None = None,
mmap: bool = False,
restrict: bool | None = None,
progress: bool = True,
) -> ModelT: ...
[docs]
def load_custom_model(
self,
path: Path,
config: ModelConfigT,
*,
gangs: Gangs | None = None,
device: Device | None = None,
dtype: DataType | None = None,
mmap: bool = False,
restrict: bool | None = None,
progress: bool = True,
) -> ModelT:
"""
Loads a model from a custom checkpoint file.
This method is useful for loading models from custom training runs or
third-party checkpoints that are not available through the asset store.
``config`` specifies the model configuration. It must match the
architecture of the saved checkpoint.
If ``gangs`` is provided, it will be used to apply parallelism (i.e.
model parallelism) to the initialized model. If the model family does
not support a certain parallelism strategy, that strategy will be
ignored. For instance if ``gangs.tp.size > 1``, but the model does not
support tensor parallelism, the model will be instantiated with regular
attention and feed-forward network blocks. If ``None``, the whole model
will be initialized without any parallelism.
If ``device`` is provided, the model will be created on the specified
device; otherwise, the device returned from :func:`torch.get_default_device`
will be used. Note that ``device`` and ``gangs`` cannot be provided
together. If ``gangs`` is provided, ``gangs.root.device`` will be used.
If ``dtype`` is provided, it will be used as the default data type of
the model parameters and buffers; otherwise, the data type returned from
:func:`torch.get_default_dtype` will be used.
If ``mmap`` is ``True``, the model checkpoint will be memory-mapped. This
can reduce memory usage but may cause slower load times on some systems.
If ``restrict`` is ``True``, pickle (if used) will be restricted to load
only tensors and types that can be safely serialized and deserialized.
If ``None``, the default restriction setting of the family will be used.
If ``progress`` is ``True``, displays a progress bar during model
download and loading.
.. code:: python
from fairseq2.models.qwen import QwenConfig, get_qwen_model_hub
checkpoint_path = ...
# The checkpoint contains a Qwen model with 16 decoder layers.
config = QwenConfig(num_layers=16)
qwen_model = get_qwen_model_hub().load_custom_model(checkpoint_path, config)
:raises ValueError: If both ``gangs`` and ``device`` are provided.
:raises FileNotFoundError: If the checkpoint file does not exist.
:raises ModelCheckpointError: If the checkpoint format is not valid or
incompatible with the model.
"""
gangs = _get_effective_gangs(gangs, device)
if dtype is None:
dtype = torch.get_default_dtype()
model = self._family.load_custom_model(
path, config, gangs, dtype, mmap, restrict, progress
)
return cast(ModelT, model)
[docs]
def iter_checkpoint(
self,
path: Path,
config: ModelConfigT,
*,
gangs: Gangs | None = None,
mmap: bool = False,
restrict: bool | None = None,
) -> Iterator[tuple[str, Tensor]]:
"""
Lazily loads parameters from the specified model checkpoint path.
Yields tensors one at a time to minimize memory usage if the underlying
checkpoint format allows it.
This method provides low-level access to checkpoint contents without
loading the full model into memory. It's useful for checkpoint inspection,
custom loading logic, or memory-efficient parameter processing.
``config`` specifies the model configuration used to determine the
expected parameter structure in the checkpoint.
If ``gangs`` is provided, it is used to determine the distributed target
configuration and to shard yielded parameters accordingly. If ``None``,
no sharding will be performed and full parameters will be yielded.
If ``mmap`` is ``True``, the checkpoint will be memory-mapped. This can
reduce memory usage but may cause slower load times on some systems.
If ``restrict`` is ``True``, pickle (if used) will be restricted to load
only tensors and types that can be safely serialized and deserialized.
If ``None``, the default restriction setting of the family will be used.
Yields pairs of ``(parameter name, parameter)`` for each parameter in
the checkpoint.
:raises FileNotFoundError: If the checkpoint file does not exist.
:raises ModelCheckpointError: If the checkpoint format is not valid.
"""
gangs = _get_effective_gangs(gangs, device=CPU)
return self._family.iter_checkpoint(path, config, gangs, mmap, restrict)
[docs]
@final
class ModelHubAccessor(Generic[ModelT, ModelConfigT]):
"""
Creates a :class:`ModelHub` instance when called.
This class provides a strongly-typed way to access model hubs. Its direct
use is meant for model authors rather than library users.
See ``src/fairseq2/models/llama/hub.py`` as an example.
.. code::
:caption: The use of `ModelHubAccessor` for model authors
from fairseq2.models import ModelHubAccessor
# Defined in the Python module where the model is implemented.
get_my_model_hub = ModelHubAccessor(
family_name="my_model_family", kls=MyModel, config_kls=MyModelConfig
)
# `get_my_model_hub()` is treated as a standalone function by the model
# users in other parts of the code like below:
model_config = MyModelConfig()
model = get_my_model_hub().create_new_model(model_config)
"""
def __init__(
self, family_name: str, kls: type[ModelT], config_kls: type[ModelConfigT]
) -> None:
self._family_name = family_name
self._kls = kls
self._config_kls = config_kls
def __call__(self) -> ModelHub[ModelT, ModelConfigT]:
resolver = get_dependency_resolver()
asset_store = resolver.resolve(AssetStore)
name = self._family_name
family = resolver.resolve_optional(ModelFamily, key=name)
if family is None:
raise ModelFamilyNotKnownError(name)
if not issubclass(family.kls, self._kls):
raise InternalError(
f"`kls` is `{self._kls}`, but the type of the {name} model family is `{family.kls}`."
)
if not issubclass(family.config_kls, self._config_kls):
raise InternalError(
f"`config_kls` is `{self._config_kls}`, but the configuration type of the {name} model family is `{family.config_kls}`."
)
return ModelHub(family, asset_store)
[docs]
class ModelNotKnownError(Exception):
"""Raised when a requested model name is not found in the asset store."""
def __init__(self, name: str) -> None:
super().__init__(f"{name} is not a known model.")
self.name = name
[docs]
class ModelFamilyNotKnownError(Exception):
"""Raised when a requested model family is not registered."""
def __init__(self, name: str) -> None:
super().__init__(f"{name} is not a known model family.")
self.name = name
[docs]
class ModelArchitectureNotKnownError(Exception):
"""
Raised when a requested model architecture is not supported by a model family.
"""
def __init__(self, arch: str, family: str | None = None) -> None:
"""
``family`` defaults to ``None`` due to backwards-compatibility. New code
must specify a model family when raising this error.
"""
if family is None:
_warn_deprecated(
"`ModelArchitectureNotKnownError` will require a `family` argument starting fairseq2 v0.12."
)
super().__init__(f"{arch} is not a known model architecture.")
else:
super().__init__(f"{arch} is not a known {family} model architecture.")
self.arch = arch
self.family = family
@overload
def load_model(
card: AssetCard | str,
*,
gangs: Gangs | None = None,
dtype: DataType | None = None,
config: object = None,
mmap: bool = False,
progress: bool = True,
) -> Module: ...
@overload
def load_model(
card: AssetCard | str,
*,
device: Device | None = None,
dtype: DataType | None = None,
config: object = None,
mmap: bool = False,
progress: bool = True,
) -> Module: ...
[docs]
def load_model(
card: AssetCard | str,
*,
gangs: Gangs | None = None,
device: Device | None = None,
dtype: DataType | None = None,
config: object = None,
mmap: bool = False,
progress: bool = True,
) -> Module:
"""
Loads a pretrained model from an asset card.
This function downloads the model checkpoint (if necessary) and loads the
pretrained weights into a model instance. The model architecture and
configuration are determined from the asset card metadata.
As a convenience, this method also accepts an asset name instead of an
asset card.
The difference between ``load_model`` and :meth:`ModelHub.load_model()` is
as follows:
- ``load_model`` provides a unified interface for loading models across all
model families. It determines the appropriate model family based on asset
card metadata and delegates to the family-specific loading logic.
- The tradeoff is that (1) the ``config`` parameter of ``load_model`` is not
type-safe, (2) it is possible to accidentally load an unintended model
since the function is not constrained to a specific family.
- The general recommendation is to use :meth:`ModelHub.load_model` if the
model family is known in advance, and to use ``load_model`` if the decision
about the model and its family needs to be made at runtime.
If ``gangs`` is provided, it will be used to apply parallelism (i.e.
model parallelism) to the initialized model. If the model family does
not support a certain parallelism strategy, that strategy will be
ignored. For instance if ``gangs.tp.size > 1``, but the model does not
support tensor parallelism, the model will be instantiated with regular
attention and feed-forward network blocks. If ``None``, the whole model
will be initialized without any parallelism.
If ``device`` is provided, the model will be created on the specified
device; otherwise, the device returned from :func:`torch.get_default_device`
will be used. Note that ``device`` and ``gangs`` cannot be provided
together. If ``gangs`` is provided, ``gangs.root.device`` will be used.
If ``dtype`` is provided, it will be used as the default data type of
the model parameters and buffers; otherwise, the data type returned from
:func:`torch.get_default_dtype` will be used.
If ``config`` is provided, it overrides the default model configuration
from the asset card. If ``None``, uses the configuration specified in
the card. Typically used to perform slight adjustments to the model
configuration such as tuning dropout probabilities without changing the
architecture.
If ``mmap`` is ``True``, the model checkpoint will be memory-mapped. This
can reduce memory usage but may cause slower load times on some systems.
If ``progress`` is ``True``, displays a progress bar during model
download and loading.
.. code:: python
from fairseq2.assets import get_asset_store
from fairseq2.models.qwen import load_model
card = get_asset_store().retrieve_card("qwen25_7b_instruct")
qwen_model = load_model(card)
# As a convenience, the card can be omitted and the model name can
# be passed directly to `load_model()`:
wav2vec2_model = load_model("wav2vec2_asr_base_10h")
:raises ModelNotKnownError: If ``card`` is a string and no asset card
with that name exists.
:raises AssetCardError: If the asset card's model family doesn't match
this hub's family.
:raises ValueError: If both ``gangs`` and ``device`` are provided.
"""
resolver = get_dependency_resolver()
global_loader = resolver.resolve(GlobalModelLoader)
return global_loader.load(card, gangs, device, dtype, config, mmap, progress)
@final
class GlobalModelLoader:
"""
A global model loader that can load models from any registered model family.
This class is used internally by the :func:`load_model` function to provide
a unified interface for loading models across all model families. It resolves
the appropriate model family based on asset card metadata and delegates to
the family-specific loading logic.
"""
def __init__(self, asset_store: AssetStore, families: Lookup[ModelFamily]) -> None:
self._asset_store = asset_store
self._families = families
def load(
self,
card: AssetCard | str,
gangs: Gangs | None,
device: Device | None,
dtype: DataType | None,
config: object | None,
mmap: bool,
progress: bool,
) -> Module:
"""See :func:`load_model`."""
gangs = _get_effective_gangs(gangs, device)
if isinstance(card, str):
name = card
try:
card = self._asset_store.retrieve_card(name)
except AssetNotFoundError:
raise ModelNotKnownError(name) from None
else:
name = card.name
family_name = card.field("model_family").as_(str)
family = self._families.maybe_get(family_name)
if family is None:
msg = f"family field of the {name} asset card is expected to be a supported model family, but is {family_name} instead."
raise AssetCardError(name, msg)
if dtype is None:
dtype = torch.get_default_dtype()
return family.load_model(card, gangs, dtype, config, mmap, progress)
def _get_effective_gangs(gangs: Gangs | None, device: Device | None) -> Gangs:
if gangs is not None:
if device is not None:
raise ValueError(
"`gangs` and `device` must not be specified at the same time."
)
return gangs
if device is None:
device = get_current_device()
if device.type == "meta":
raise ValueError("`device` must be a real device.")
return create_fake_gangs(device)