# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import Generic, TypeVar, cast, final
import torch
from torch import Tensor
from torch.nn import Module
from fairseq2.assets import AssetCard, AssetCardError, AssetNotFoundError, AssetStore
from fairseq2.data_type import DataType
from fairseq2.device import Device
from fairseq2.error import InternalError
from fairseq2.gang import Gangs, create_fake_gangs
from fairseq2.models.family import ModelFamily
from fairseq2.runtime.dependency import get_dependency_resolver
from fairseq2.runtime.lookup import Lookup
ModelT = TypeVar("ModelT", bound=Module)
ModelConfigT = TypeVar("ModelConfigT")
[docs]
@final
class ModelHub(Generic[ModelT, ModelConfigT]):
def __init__(self, family: ModelFamily, asset_store: AssetStore) -> None:
self._family = family
self._asset_store = asset_store
def iter_cards(self) -> Iterator[AssetCard]:
return self._asset_store.find_cards("model_family", self._family.name)
def get_archs(self) -> set[str]:
return self._family.get_archs()
def get_arch_config(self, arch: str) -> ModelConfigT:
config = self.maybe_get_arch_config(arch)
if config is None:
raise ModelArchitectureNotKnownError(arch)
return config
def maybe_get_arch_config(self, arch: str) -> ModelConfigT | None:
config = self._family.maybe_get_arch_config(arch)
return cast(ModelConfigT | None, config)
def get_model_config(self, card: AssetCard | str) -> ModelConfigT:
if isinstance(card, str):
name = card
try:
card = self._asset_store.retrieve_card(name)
except AssetNotFoundError:
raise ModelNotKnownError(name) from None
else:
name = card.name
family_name = card.field("model_family").as_(str)
if family_name != self._family.name:
msg = f"family field of the {name} asset card is expected to be {self._family.name}, but is {family_name} instead."
raise AssetCardError(name, msg)
config = self._family.get_model_config(card)
return cast(ModelConfigT, config)
def create_new_model(
self,
config: ModelConfigT,
*,
gangs: Gangs | None = None,
device: Device | None = None,
dtype: DataType | None = None,
meta: bool = False,
) -> ModelT:
gangs = _get_effective_gangs(gangs, device)
if dtype is None:
dtype = torch.get_default_dtype()
model = self._family.create_new_model(config, gangs, dtype, meta)
return cast(ModelT, model)
def load_model(
self,
card: AssetCard | str,
*,
gangs: Gangs | None = None,
device: Device | None = None,
dtype: DataType | None = None,
config: ModelConfigT | None = None,
mmap: bool = False,
progress: bool = True,
) -> ModelT:
gangs = _get_effective_gangs(gangs, device)
if isinstance(card, str):
name = card
try:
card = self._asset_store.retrieve_card(name)
except AssetNotFoundError:
raise ModelNotKnownError(name) from None
else:
name = card.name
family_name = card.field("model_family").as_(str)
if family_name != self._family.name:
msg = f"family field of the {name} asset card is expected to be {self._family.name}, but is {family_name} instead."
raise AssetCardError(name, msg)
if dtype is None:
dtype = torch.get_default_dtype()
model = self._family.load_model(card, gangs, dtype, config, mmap, progress)
return cast(ModelT, model)
def load_custom_model(
self,
path: Path,
config: ModelConfigT,
*,
gangs: Gangs | None = None,
device: Device | None = None,
dtype: DataType | None = None,
mmap: bool = False,
restrict: bool | None = None,
progress: bool = True,
) -> ModelT:
gangs = _get_effective_gangs(gangs, device)
if dtype is None:
dtype = torch.get_default_dtype()
model = self._family.load_custom_model(
path, config, gangs, dtype, mmap, restrict, progress
)
return cast(ModelT, model)
def iter_checkpoint(
self,
path: Path,
config: ModelConfigT,
*,
gangs: Gangs | None = None,
mmap: bool = False,
restrict: bool | None = None,
) -> Iterator[tuple[str, Tensor]]:
gangs = _get_effective_gangs(gangs, device=None)
return self._family.iter_checkpoint(path, config, gangs, mmap, restrict)
[docs]
@final
class ModelHubAccessor(Generic[ModelT, ModelConfigT]):
def __init__(
self, family_name: str, kls: type[ModelT], config_kls: type[ModelConfigT]
) -> None:
self._family_name = family_name
self._kls = kls
self._config_kls = config_kls
def __call__(self) -> ModelHub[ModelT, ModelConfigT]:
resolver = get_dependency_resolver()
asset_store = resolver.resolve(AssetStore)
name = self._family_name
family = resolver.resolve_optional(ModelFamily, key=name)
if family is None:
raise ModelFamilyNotKnownError(name)
if not issubclass(family.kls, self._kls):
raise InternalError(
f"`kls` is `{self._kls}`, but the type of the {name} model family is `{family.kls}`."
)
if not issubclass(family.config_kls, self._config_kls):
raise InternalError(
f"`config_kls` is `{self._config_kls}`, but the configuration type of the {name} model family is `{family.config_kls}`."
)
return ModelHub(family, asset_store)
[docs]
class ModelNotKnownError(Exception):
def __init__(self, name: str) -> None:
super().__init__(f"{name} is not a known model.")
self.name = name
[docs]
class ModelFamilyNotKnownError(Exception):
def __init__(self, name: str) -> None:
super().__init__(f"{name} is not a known model family.")
self.name = name
[docs]
class ModelArchitectureNotKnownError(Exception):
def __init__(self, arch: str) -> None:
super().__init__(f"{arch} is not a known model architecture.")
self.arch = arch
[docs]
def load_model(
card: AssetCard | str,
*,
gangs: Gangs | None = None,
device: Device | None = None,
dtype: DataType | None = None,
config: object = None,
mmap: bool = False,
progress: bool = True,
) -> Module:
resolver = get_dependency_resolver()
global_loader = resolver.resolve(GlobalModelLoader)
return global_loader.load(card, gangs, device, dtype, config, mmap, progress)
@final
class GlobalModelLoader:
def __init__(self, asset_store: AssetStore, families: Lookup[ModelFamily]) -> None:
self._asset_store = asset_store
self._families = families
def load(
self,
card: AssetCard | str,
gangs: Gangs | None,
device: Device | None,
dtype: DataType | None,
config: object | None,
mmap: bool,
progress: bool,
) -> Module:
gangs = _get_effective_gangs(gangs, device)
if isinstance(card, str):
name = card
try:
card = self._asset_store.retrieve_card(name)
except AssetNotFoundError:
raise ModelNotKnownError(name) from None
else:
name = card.name
family_name = card.field("model_family").as_(str)
family = self._families.maybe_get(family_name)
if family is None:
msg = f"family field of the {name} asset card is expected to be a supported model family, but is {family_name} instead."
raise AssetCardError(name, msg)
if dtype is None:
dtype = torch.get_default_dtype()
return family.load_model(card, gangs, dtype, config, mmap, progress)
def _get_effective_gangs(gangs: Gangs | None, device: Device | None) -> Gangs:
if gangs is not None and device is not None:
raise ValueError("`gangs` and `device` must not be specified at the same time.")
if device is not None:
if device.type == "meta":
raise ValueError("`device` must be a real device.")
return create_fake_gangs(device)
if gangs is None:
device = torch.get_default_device()
return create_fake_gangs(device)
if gangs.root.device.type == "meta":
raise ValueError("`gangs.root` must be on a real device.")
return gangs