# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import Generic, TypeVar, cast, final
from fairseq2.assets import AssetCard, AssetCardError, AssetNotFoundError, AssetStore
from fairseq2.data.tokenizers.family import (
TokenizerFamily,
TokenizerFamilyNotKnownError,
)
from fairseq2.data.tokenizers.ref import resolve_tokenizer_reference
from fairseq2.data.tokenizers.tokenizer import Tokenizer
from fairseq2.device import CPU
from fairseq2.error import InternalError
from fairseq2.gang import Gangs, create_fake_gangs
from fairseq2.runtime.dependency import get_dependency_resolver
from fairseq2.runtime.lookup import Lookup
from fairseq2.utils.warn import _warn_progress_deprecated
# Type variable for the tokenizer class handled by a hub; bound to `Tokenizer`.
TokenizerT = TypeVar("TokenizerT", bound=Tokenizer)
# Type variable for the tokenizer configuration class handled by a hub (unbounded).
TokenizerConfigT = TypeVar("TokenizerConfigT")
[docs]
@final
class TokenizerHub(Generic[TokenizerT, TokenizerConfigT]):
    """Gives access to the tokenizers of a single tokenizer family.

    The hub pairs a :class:`TokenizerFamily` with an :class:`AssetStore` and
    delegates configuration retrieval and loading to the family after the
    requested asset card has been looked up and validated.
    """

    def __init__(self, family: TokenizerFamily, asset_store: AssetStore) -> None:
        """
        :param family: The tokenizer family that this hub serves.
        :param asset_store: The store used to look up asset cards.
        """
        self._family = family
        self._asset_store = asset_store

    def iter_cards(self) -> Iterator[AssetCard]:
        """Return the cards that the asset store finds for the
        ``tokenizer_family`` field and the name of this hub's family."""
        return self._asset_store.find_cards("tokenizer_family", self._family.name)

    def get_tokenizer_config(self, card: AssetCard | str) -> TokenizerConfigT:
        """Return the tokenizer configuration that the family reports for ``card``.

        :param card: An asset card or the name of an asset card.

        :raises TokenizerNotKnownError: If ``card`` is a name and the asset
            store raises :class:`AssetNotFoundError` for it.
        :raises AssetCardError: If the ``tokenizer_family`` field of the card
            differs from the name of this hub's family.
        """
        card = self._resolve_card(card)

        config = self._family.get_tokenizer_config(card)

        return cast(TokenizerConfigT, config)

    def load_tokenizer(
        self,
        card: AssetCard | str,
        *,
        gangs: Gangs | None = None,
        config: TokenizerConfigT | None = None,
        progress: bool | None = None,
    ) -> TokenizerT:
        """Load the tokenizer described by ``card`` through this hub's family.

        :param card: An asset card or the name of an asset card.
        :param gangs: The gangs passed to the family; when ``None``, the
            result of ``create_fake_gangs(CPU)`` is used.
        :param config: Forwarded unchanged to the family's ``load_tokenizer``.
        :param progress: Not used for loading; only passed to
            ``_warn_progress_deprecated``.

        :raises TokenizerNotKnownError: If ``card`` is a name and the asset
            store raises :class:`AssetNotFoundError` for it.
        :raises AssetCardError: If the ``tokenizer_family`` field of the card
            differs from the name of this hub's family.
        """
        _warn_progress_deprecated(progress)

        if gangs is None:
            gangs = create_fake_gangs(CPU)

        card = self._resolve_card(card)

        tokenizer = self._family.load_tokenizer(card, gangs, config)

        return cast(TokenizerT, tokenizer)

    def load_custom_tokenizer(
        self, path: Path, config: TokenizerConfigT, *, gangs: Gangs | None = None
    ) -> TokenizerT:
        """Load a tokenizer from ``path`` through this hub's family.

        :param path: Forwarded unchanged to the family's
            ``load_custom_tokenizer``.
        :param config: Forwarded unchanged to the family's
            ``load_custom_tokenizer``.
        :param gangs: The gangs passed to the family; when ``None``, the
            result of ``create_fake_gangs(CPU)`` is used.
        """
        if gangs is None:
            gangs = create_fake_gangs(CPU)

        tokenizer = self._family.load_custom_tokenizer(path, config, gangs)

        return cast(TokenizerT, tokenizer)

    def _resolve_card(self, card: AssetCard | str) -> AssetCard:
        """Look up ``card``, resolve its tokenizer reference, and check that it
        belongs to this hub's family.

        Shared by :meth:`get_tokenizer_config` and :meth:`load_tokenizer` so
        that both apply the same lookup and validation.
        """
        if isinstance(card, str):
            name = card

            try:
                card = self._asset_store.retrieve_card(name)
            except AssetNotFoundError:
                raise TokenizerNotKnownError(name) from None
        else:
            name = card.name

        # `name` is captured before this call so that error messages refer to
        # the card that the caller asked for.
        card = resolve_tokenizer_reference(self._asset_store, card)

        family_name = card.field("tokenizer_family").as_(str)

        if family_name != self._family.name:
            msg = f"family field of the {name} asset card is expected to be {self._family.name}, but is {family_name} instead."

            raise AssetCardError(name, msg)

        return card
[docs]
@final
class TokenizerHubAccessor(Generic[TokenizerT, TokenizerConfigT]):
    """Callable that builds a :class:`TokenizerHub` for a named tokenizer family.

    The family and the asset store are obtained from the resolver returned by
    ``get_dependency_resolver()`` each time the accessor is called.
    """

    def __init__(
        self,
        family_name: str,
        kls: type[TokenizerT],
        config_kls: type[TokenizerConfigT],
    ) -> None:
        """
        :param family_name: The key under which the tokenizer family is
            resolved.
        :param kls: The tokenizer type that the family's ``kls`` must be a
            subclass of.
        :param config_kls: The configuration type that the family's
            ``config_kls`` must be a subclass of.
        """
        self._family_name = family_name
        self._kls = kls
        self._config_kls = config_kls

    def __call__(self) -> TokenizerHub[TokenizerT, TokenizerConfigT]:
        """Return a hub for the family registered under ``family_name``.

        :raises TokenizerFamilyNotKnownError: If no family is resolved for
            ``family_name``.
        :raises InternalError: If the family's ``kls`` or ``config_kls`` is not
            a subclass of the type given to this accessor.
        """
        dependency_resolver = get_dependency_resolver()

        # The asset store is resolved first, before the family lookup.
        store = dependency_resolver.resolve(AssetStore)

        name = self._family_name

        family = dependency_resolver.maybe_resolve(TokenizerFamily, key=name)
        if family is None:
            raise TokenizerFamilyNotKnownError(name)

        kls_matches = issubclass(family.kls, self._kls)
        if not kls_matches:
            raise InternalError(
                f"`kls` is `{self._kls}`, but the type of the {name} tokenizer family is `{family.kls}`."
            )

        config_kls_matches = issubclass(family.config_kls, self._config_kls)
        if not config_kls_matches:
            raise InternalError(
                f"`config_kls` is `{self._config_kls}`, but the configuration type of the {name} tokenizer family is `{family.config_kls}`."
            )

        return TokenizerHub(family, store)
class TokenizerNotKnownError(Exception):
    """Signals that a tokenizer name is not known.

    The offending name is kept in the ``name`` attribute.
    """

    def __init__(self, name: str) -> None:
        """
        :param name: The tokenizer name that is not known.
        """
        message = f"{name} is not a known tokenizer."

        super().__init__(message)

        self.name = name
[docs]
def load_tokenizer(
    card: AssetCard | str,
    *,
    gangs: Gangs | None = None,
    config: object | None = None,
    progress: bool | None = None,
) -> Tokenizer:
    """Load the tokenizer described by ``card`` using the
    :class:`GlobalTokenizerLoader` obtained from the dependency resolver.

    :param card: An asset card or the name of an asset card.
    :param gangs: Forwarded unchanged to the loader.
    :param config: Forwarded unchanged to the loader.
    :param progress: Not used for loading; only passed to
        ``_warn_progress_deprecated``.
    """
    _warn_progress_deprecated(progress)

    loader = get_dependency_resolver().resolve(GlobalTokenizerLoader)

    return loader.load(card, gangs, config)
@final
class GlobalTokenizerLoader:
    """Loads a tokenizer by dispatching to the family named in its asset card."""

    def __init__(
        self, asset_store: AssetStore, families: Lookup[TokenizerFamily]
    ) -> None:
        """
        :param asset_store: The store used to look up asset cards.
        :param families: The lookup used to find a tokenizer family by name.
        """
        self._asset_store = asset_store
        self._families = families

    def load(
        self, card: AssetCard | str, gangs: Gangs | None, config: object | None
    ) -> Tokenizer:
        """Load the tokenizer described by ``card``.

        :param card: An asset card or the name of an asset card.
        :param gangs: The gangs passed to the family; when ``None``, the
            result of ``create_fake_gangs(CPU)`` is used.
        :param config: Forwarded unchanged to the family's ``load_tokenizer``.

        :raises TokenizerNotKnownError: If ``card`` is a name and the asset
            store raises :class:`AssetNotFoundError` for it.
        :raises AssetCardError: If the lookup has no family for the card's
            ``tokenizer_family`` field.
        """
        if gangs is None:
            gangs = create_fake_gangs(CPU)

        if not isinstance(card, str):
            name = card.name
        else:
            name = card

            try:
                card = self._asset_store.retrieve_card(name)
            except AssetNotFoundError:
                raise TokenizerNotKnownError(name) from None

        # `name` was captured above so that the error message refers to the
        # card that the caller asked for.
        resolved_card = resolve_tokenizer_reference(self._asset_store, card)

        family_name = resolved_card.field("tokenizer_family").as_(str)

        family = self._families.maybe_get(family_name)
        if family is not None:
            return family.load_tokenizer(resolved_card, gangs, config)

        msg = f"family field of the {name} asset card is expected to be a supported tokenizer family, but is {family_name} instead."

        raise AssetCardError(name, msg)