Source code for fairseq2.data.tokenizers.hub

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path
from typing import Generic, TypeVar, cast, final

from fairseq2.assets import AssetCard, AssetCardError, AssetNotFoundError, AssetStore
from fairseq2.data.tokenizers.family import TokenizerFamily
from fairseq2.data.tokenizers.ref import resolve_tokenizer_reference
from fairseq2.data.tokenizers.tokenizer import Tokenizer
from fairseq2.error import InternalError
from fairseq2.runtime.dependency import get_dependency_resolver
from fairseq2.runtime.lookup import Lookup

TokenizerT = TypeVar("TokenizerT", bound=Tokenizer)

TokenizerConfigT = TypeVar("TokenizerConfigT")


@final
class TokenizerHub(Generic[TokenizerT, TokenizerConfigT]):
    """Gives access to the tokenizers of a single tokenizer family.

    Asset cards are looked up through the asset store passed at construction,
    while configuration retrieval and tokenizer loading are delegated to the
    family object.
    """

    def __init__(self, family: TokenizerFamily, asset_store: AssetStore) -> None:
        """
        :param family: The tokenizer family served by this hub.
        :param asset_store: The store used to find and retrieve asset cards.
        """
        self._family = family
        self._asset_store = asset_store

    def iter_cards(self) -> Iterator[AssetCard]:
        """Query the asset store for cards by the ``tokenizer_family`` field
        and the name of this hub's family.
        """
        return self._asset_store.find_cards("tokenizer_family", self._family.name)

    def get_tokenizer_config(self, card: AssetCard | str) -> TokenizerConfigT:
        """Return the configuration that the family produces for ``card``.

        :param card: An asset card, or the name of a card to retrieve from the
            asset store.

        :raises TokenizerNotKnownError: If ``card`` is a name and the asset
            store reports that no such card exists.
        :raises AssetCardError: If the ``tokenizer_family`` field of the
            resolved card differs from the name of this hub's family.
        """
        card = self._resolve_card(card)

        config = self._family.get_tokenizer_config(card)

        # The family API is not generic; narrow to the hub's config type.
        return cast(TokenizerConfigT, config)

    def load_tokenizer(
        self,
        card: AssetCard | str,
        *,
        config: TokenizerConfigT | None = None,
        progress: bool = True,
    ) -> TokenizerT:
        """Load the tokenizer described by ``card`` through the family.

        :param card: An asset card, or the name of a card to retrieve from the
            asset store.
        :param config: Forwarded as-is to the family's ``load_tokenizer``;
            ``None`` by default.
        :param progress: Forwarded as-is to the family's ``load_tokenizer``.

        :raises TokenizerNotKnownError: If ``card`` is a name and the asset
            store reports that no such card exists.
        :raises AssetCardError: If the ``tokenizer_family`` field of the
            resolved card differs from the name of this hub's family.
        """
        card = self._resolve_card(card)

        tokenizer = self._family.load_tokenizer(card, config, progress)

        # The family API is not generic; narrow to the hub's tokenizer type.
        return cast(TokenizerT, tokenizer)

    def load_custom_tokenizer(self, path: Path, config: TokenizerConfigT) -> TokenizerT:
        """Load a tokenizer from ``path`` with ``config`` through the family's
        ``load_custom_tokenizer``; no asset card is involved.
        """
        tokenizer = self._family.load_custom_tokenizer(path, config)

        return cast(TokenizerT, tokenizer)

    def _resolve_card(self, card: AssetCard | str) -> AssetCard:
        """Turn ``card`` into a reference-resolved card of this hub's family.

        Shared by :meth:`get_tokenizer_config` and :meth:`load_tokenizer` so
        that both apply the same lookup and validation steps.
        """
        if isinstance(card, str):
            name = card

            try:
                card = self._asset_store.retrieve_card(name)
            except AssetNotFoundError:
                # Deliberately drop the lookup error as context; callers only
                # need to know that the name is not a known tokenizer.
                raise TokenizerNotKnownError(name) from None
        else:
            name = card.name

        # ``name`` keeps identifying the card the caller asked for, whatever
        # card the reference resolution returns.
        card = resolve_tokenizer_reference(self._asset_store, card)

        family_name = card.field("tokenizer_family").as_(str)

        if family_name != self._family.name:
            msg = f"family field of the {name} asset card is expected to be {self._family.name}, but is {family_name} instead."

            raise AssetCardError(name, msg)

        return card
@final
class TokenizerHubAccessor(Generic[TokenizerT, TokenizerConfigT]):
    """Callable that builds a :class:`TokenizerHub` for a named family.

    Calling an instance resolves the asset store and the tokenizer family
    from the dependency resolver and checks the family's declared types
    against the ones given at construction.
    """

    def __init__(
        self,
        family_name: str,
        kls: type[TokenizerT],
        config_kls: type[TokenizerConfigT],
    ) -> None:
        """
        :param family_name: The key under which the family is resolved.
        :param kls: The type that the family's ``kls`` must be a subclass of.
        :param config_kls: The type that the family's ``config_kls`` must be a
            subclass of.
        """
        self._family_name = family_name
        self._kls = kls
        self._config_kls = config_kls

    def __call__(self) -> TokenizerHub[TokenizerT, TokenizerConfigT]:
        """Return a hub for the family named at construction.

        :raises TokenizerFamilyNotKnownError: If the resolver has no family
            registered under the name.
        :raises InternalError: If the family's ``kls`` or ``config_kls`` is
            not a subclass of the corresponding type given at construction.
        """
        deps = get_dependency_resolver()

        # The asset store is resolved first, before the family is looked up.
        store = deps.resolve(AssetStore)

        name = self._family_name

        family = deps.resolve_optional(TokenizerFamily, key=name)
        if family is None:
            raise TokenizerFamilyNotKnownError(name)

        if not issubclass(family.kls, self._kls):
            kls_msg = f"`kls` is `{self._kls}`, but the type of the {name} tokenizer family is `{family.kls}`."

            raise InternalError(kls_msg)

        if not issubclass(family.config_kls, self._config_kls):
            config_msg = f"`config_kls` is `{self._config_kls}`, but the configuration type of the {name} tokenizer family is `{family.config_kls}`."

            raise InternalError(config_msg)

        return TokenizerHub(family, store)
class TokenizerNotKnownError(Exception):
    """Signals that a name does not refer to a known tokenizer.

    The name is available afterwards as the ``name`` attribute.
    """

    def __init__(self, name: str) -> None:
        """
        :param name: The tokenizer name that could not be found.
        """
        message = f"{name} is not a known tokenizer."

        super().__init__(message)

        self.name = name
class TokenizerFamilyNotKnownError(Exception):
    """Signals that a name does not refer to a known tokenizer family.

    The name is available afterwards as the ``name`` attribute.
    """

    def __init__(self, name: str) -> None:
        """
        :param name: The tokenizer family name that could not be found.
        """
        message = f"{name} is not a known tokenizer family."

        super().__init__(message)

        self.name = name
def load_tokenizer(
    card: AssetCard | str, *, config: object = None, progress: bool = True
) -> Tokenizer:
    """Load a tokenizer through the globally registered loader.

    Obtains a :class:`GlobalTokenizerLoader` from the dependency resolver and
    hands all arguments to its ``load`` method.

    :param card: An asset card, or the name of one.
    :param config: Forwarded as-is to the loader; ``None`` by default.
    :param progress: Forwarded as-is to the loader.
    """
    loader = get_dependency_resolver().resolve(GlobalTokenizerLoader)

    return loader.load(card, config, progress)
@final
class GlobalTokenizerLoader:
    """Loads a tokenizer of any registered family from an asset card.

    The family is chosen from the card's ``tokenizer_family`` field rather
    than being fixed up front.
    """

    def __init__(
        self, asset_store: AssetStore, families: Lookup[TokenizerFamily]
    ) -> None:
        """
        :param asset_store: The store used to retrieve asset cards by name.
        :param families: The lookup queried with the family name of a card.
        """
        self._asset_store = asset_store
        self._families = families

    def load(
        self, card: AssetCard | str, config: object | None, progress: bool
    ) -> Tokenizer:
        """Load the tokenizer described by ``card``.

        :param card: An asset card, or the name of a card to retrieve from the
            asset store.
        :param config: Forwarded as-is to the family's ``load_tokenizer``.
        :param progress: Forwarded as-is to the family's ``load_tokenizer``.

        :raises TokenizerNotKnownError: If ``card`` is a name and the asset
            store reports that no such card exists.
        :raises AssetCardError: If the lookup has no family for the value of
            the card's ``tokenizer_family`` field.
        """
        if not isinstance(card, str):
            name = card.name

            requested = card
        else:
            name = card

            try:
                requested = self._asset_store.retrieve_card(name)
            except AssetNotFoundError:
                # Drop the lookup error as context on purpose.
                raise TokenizerNotKnownError(name) from None

        # ``name`` keeps identifying the card the caller asked for, whatever
        # card the reference resolution returns.
        resolved = resolve_tokenizer_reference(self._asset_store, requested)

        family_name = resolved.field("tokenizer_family").as_(str)

        family = self._families.maybe_get(family_name)
        if family is not None:
            return family.load_tokenizer(resolved, config, progress)

        msg = f"family field of the {name} asset card is expected to be a supported tokenizer family, but is {family_name} instead."

        raise AssetCardError(name, msg)