# Copyright (c) Meta Platforms, Inc. and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree.from__future__importannotationsfromcollections.abcimportIteratorfromtypingimportGeneric,TypeVar,cast,finalfromfairseq2.assetsimportAssetCard,AssetCardError,AssetNotFoundError,AssetStorefromfairseq2.datasets.familyimportDatasetFamilyfromfairseq2.errorimportInternalErrorfromfairseq2.runtime.dependencyimportget_dependency_resolverDatasetT=TypeVar("DatasetT")DatasetConfigT=TypeVar("DatasetConfigT")
@final
class DatasetHub(Generic[DatasetT, DatasetConfigT]):
    """Gives access to the datasets of a single dataset family.

    A hub pairs a :class:`DatasetFamily` with an :class:`AssetStore`. Datasets
    can be referenced either by an :class:`AssetCard` or by the name of a card,
    in which case the card is retrieved from the asset store.
    """

    def __init__(self, family: DatasetFamily, asset_store: AssetStore) -> None:
        """
        :param family: The dataset family that this hub delegates to.
        :param asset_store: The store used to find and retrieve asset cards.
        """
        self._family = family
        self._asset_store = asset_store

    def iter_cards(self) -> Iterator[AssetCard]:
        """Return the result of querying the asset store for cards by the
        ``dataset_family`` field and the name of this hub's family."""
        return self._asset_store.find_cards("dataset_family", self._family.name)

    def get_dataset_config(self, card: AssetCard | str) -> DatasetConfigT:
        """Return the dataset configuration that the family produces for ``card``.

        :param card: An asset card, or the name of a card to retrieve from the
            asset store.

        :raises DatasetNotKnownError: ``card`` is a name and the asset store
            has no card with that name.
        :raises AssetCardError: The ``dataset_family`` field of the card does
            not match the name of this hub's family.
        """
        card = self._resolve_card(card)

        config = self._family.get_dataset_config(card)

        # The family API is not generic; narrow for the type checker only.
        return cast(DatasetConfigT, config)

    def open_dataset(
        self, card: AssetCard | str, *, config: DatasetConfigT | None = None
    ) -> DatasetT:
        """Open the dataset described by ``card``.

        :param card: An asset card, or the name of a card to retrieve from the
            asset store.
        :param config: Forwarded unchanged to the family's ``open_dataset``.
            NOTE(review): how ``None`` is treated is decided by the family
            implementation, which is not visible here.

        :raises DatasetNotKnownError: ``card`` is a name and the asset store
            has no card with that name.
        :raises AssetCardError: The ``dataset_family`` field of the card does
            not match the name of this hub's family.
        """
        card = self._resolve_card(card)

        dataset = self._family.open_dataset(card, config)

        # The family API is not generic; narrow for the type checker only.
        return cast(DatasetT, dataset)

    def open_custom_dataset(self, config: DatasetConfigT) -> DatasetT:
        """Open a dataset from ``config`` alone, without an asset card.

        :param config: Forwarded unchanged to the family's
            ``open_custom_dataset``.
        """
        dataset = self._family.open_custom_dataset(config)

        return cast(DatasetT, dataset)

    def _resolve_card(self, card: AssetCard | str) -> AssetCard:
        """Turn ``card`` into an asset card that belongs to this hub's family.

        Shared by :meth:`get_dataset_config` and :meth:`open_dataset` so that
        both resolve and validate cards in exactly the same way.

        :raises DatasetNotKnownError: ``card`` is a name and the asset store
            has no card with that name.
        :raises AssetCardError: The ``dataset_family`` field of the card does
            not match the name of this hub's family.
        """
        if isinstance(card, str):
            name = card

            try:
                card = self._asset_store.retrieve_card(name)
            except AssetNotFoundError:
                # Hide the asset-level error; callers asked for a dataset.
                raise DatasetNotKnownError(name) from None
        else:
            name = card.name

        family_name = card.field("dataset_family").as_(str)

        if family_name != self._family.name:
            msg = f"family field of the {name} asset card is expected to be {self._family.name}, but is {family_name} instead."

            # Report under the name the caller used, not the card's own name.
            raise AssetCardError(name, msg)

        return card
@final
class DatasetHubAccessor(Generic[DatasetT, DatasetConfigT]):
    """Callable that builds the :class:`DatasetHub` of a dataset family.

    The family is looked up by name through the dependency resolver each time
    the accessor is called, and its dataset and configuration types are
    checked against the types given at construction.
    """

    def __init__(
        self,
        family_name: str,
        kls: type[DatasetT],
        config_kls: type[DatasetConfigT],
    ) -> None:
        """
        :param family_name: The key under which the dataset family is
            registered with the dependency resolver.
        :param kls: The type that the family's ``kls`` must be a subclass of.
        :param config_kls: The type that the family's ``config_kls`` must be a
            subclass of.
        """
        self._family_name = family_name
        self._kls = kls
        self._config_kls = config_kls

    def __call__(self) -> DatasetHub[DatasetT, DatasetConfigT]:
        """Resolve the asset store and the family, and return a new hub.

        :raises DatasetFamilyNotKnownError: No dataset family is registered
            under the family name.
        :raises InternalError: The family's dataset type or configuration
            type is not a subclass of the type given at construction.
        """
        resolver = get_dependency_resolver()

        # The asset store is resolved before the family, as in the original.
        asset_store = resolver.resolve(AssetStore)

        name = self._family_name

        family = resolver.resolve_optional(DatasetFamily, key=name)
        if family is None:
            raise DatasetFamilyNotKnownError(name)

        expected_kls = self._kls
        actual_kls = family.kls

        if not issubclass(actual_kls, expected_kls):
            message = f"`kls` is `{expected_kls}`, but the type of the {name} dataset family is `{actual_kls}`."

            raise InternalError(message)

        expected_config_kls = self._config_kls
        actual_config_kls = family.config_kls

        if not issubclass(actual_config_kls, expected_config_kls):
            message = f"`config_kls` is `{expected_config_kls}`, but the configuration type of the {name} dataset family is `{actual_config_kls}`."

            raise InternalError(message)

        return DatasetHub(family, asset_store)
class DatasetNotKnownError(Exception):
    """Raised when a dataset name does not correspond to a known dataset.

    The offending name is kept in the ``name`` attribute.
    """

    def __init__(self, name: str) -> None:
        """
        :param name: The dataset name that could not be found.
        """
        message = f"{name} is not a known dataset."

        super().__init__(message)

        self.name = name
class DatasetFamilyNotKnownError(Exception):
    """Raised when a name does not correspond to a known dataset family.

    The offending name is kept in the ``name`` attribute.
    """

    def __init__(self, name: str) -> None:
        """
        :param name: The dataset family name that could not be found.
        """
        # Message typo fixed: it previously read "is not a know dataset family."
        super().__init__(f"{name} is not a known dataset family.")

        self.name = name