# Source code for fairseq2.datasets.loader

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Protocol, TypeVar, cast, final

from fairseq2.assets import (
    AssetCard,
    AssetCardError,
    AssetDownloadManager,
    AssetError,
    AssetStore,
    default_asset_download_manager,
    default_asset_store,
)

# Invariant type variable for the dataset type produced by a concrete loader;
# used by ``AbstractDatasetLoader`` and ``DelegatingDatasetLoader`` below.
DatasetT = TypeVar("DatasetT")

# Covariant counterpart used to parameterize the ``DatasetLoader`` protocol,
# where the type appears only as the return type of ``__call__``.
DatasetT_co = TypeVar("DatasetT_co", covariant=True)


class DatasetLoader(Protocol[DatasetT_co]):
    """Loads datasets of type ``DatasetT``."""

    def __call__(
        self,
        dataset_name_or_card: str | AssetCard,
        *,
        force: bool = False,
        progress: bool = True,
    ) -> DatasetT_co:
        """
        :param dataset_name_or_card:
            The name or the asset card of the dataset to load.
        :param force:
            If ``True``, downloads the dataset even if it is already in cache.
        :param progress:
            If ``True``, displays a progress bar to stderr.

        :returns:
            The loaded dataset.
        """
class AbstractDatasetLoader(ABC, DatasetLoader[DatasetT]):
    """Provides a skeletal implementation of :class:`DatasetLoader`.

    Subclasses only implement :meth:`_load`; resolving the asset card and
    downloading the dataset are handled by :meth:`__call__`.
    """

    _asset_store: AssetStore
    _download_manager: AssetDownloadManager

    def __init__(
        self,
        *,
        asset_store: AssetStore | None = None,
        download_manager: AssetDownloadManager | None = None,
    ) -> None:
        """
        :param asset_store:
            The asset store where to check for available datasets. If ``None``,
            the default asset store will be used.
        :param download_manager:
            The download manager. If ``None``, the default download manager
            will be used.
        """
        self._asset_store = asset_store or default_asset_store

        self._download_manager = download_manager or default_asset_download_manager

    @final
    def __call__(
        self,
        dataset_name_or_card: str | AssetCard,
        *,
        force: bool = False,
        progress: bool = True,
    ) -> DatasetT:
        """
        :param dataset_name_or_card:
            The name or the asset card of the dataset to load.
        :param force:
            Forwarded to the download manager as ``force``.
        :param progress:
            Forwarded to the download manager as ``progress``.

        :returns:
            The dataset returned by :meth:`_load`.

        :raises AssetCardError:
            If the download manager raises :class:`ValueError` for the URI
            held in the 'data' field of the card.
        :raises AssetError:
            If :meth:`_load` raises :class:`ValueError`.
        """
        # Accept either an already resolved card or a name to look up.
        if isinstance(dataset_name_or_card, AssetCard):
            card = dataset_name_or_card
        else:
            card = self._asset_store.retrieve_card(dataset_name_or_card)

        dataset_uri = card.field("data").as_uri()

        try:
            path = self._download_manager.download_dataset(
                dataset_uri, card.name, force=force, progress=progress
            )
        except ValueError as ex:
            # Surface a download-time ValueError as a problem with the card.
            raise AssetCardError(
                f"The value of the field 'data' of the asset card '{card.name}' must be a URI. See nested exception for details."
            ) from ex

        try:
            return self._load(path, card)
        except ValueError as ex:
            raise AssetError(
                f"The {card.name} dataset cannot be loaded. See nested exception for details."
            ) from ex

    @abstractmethod
    def _load(self, path: Path, card: AssetCard) -> DatasetT:
        """
        :param path:
            The path to the dataset.
        :param card:
            The asset card of the dataset.
        """
@final
class DelegatingDatasetLoader(DatasetLoader[DatasetT]):
    """Loads datasets of type ``DatasetT`` using registered loaders.

    The loader to delegate to is selected by the 'dataset_family' field of the
    dataset's asset card.
    """

    _asset_store: AssetStore
    _loaders: dict[str, DatasetLoader[DatasetT]]

    def __init__(self, *, asset_store: AssetStore | None = None) -> None:
        """
        :param asset_store:
            The asset store where to check for available datasets. If ``None``,
            the default asset store will be used.
        """
        self._asset_store = asset_store or default_asset_store

        self._loaders = {}

    def __call__(
        self,
        dataset_name_or_card: str | AssetCard,
        *,
        force: bool = False,
        progress: bool = True,
    ) -> DatasetT:
        """
        :param dataset_name_or_card:
            The name or the asset card of the dataset to load.
        :param force:
            Forwarded to the selected loader as ``force``.
        :param progress:
            Forwarded to the selected loader as ``progress``.

        :returns:
            The dataset returned by the loader registered for the card's
            dataset family.

        :raises AssetError:
            If no loader is registered for the card's dataset family.
        """
        # Accept either an already resolved card or a name to look up.
        if isinstance(dataset_name_or_card, AssetCard):
            card = dataset_name_or_card
        else:
            card = self._asset_store.retrieve_card(dataset_name_or_card)

        family = card.field("dataset_family").as_(str)

        try:
            loader = self._loaders[family]
        except KeyError:
            # ``from None`` hides the internal KeyError from the traceback.
            raise AssetError(
                f"The value of the field 'dataset_family' of the asset card '{card.name}' must be a supported dataset family, but is '{family}' instead."
            ) from None

        # Pass the resolved card on so the delegate does not look it up again.
        return loader(card, force=force, progress=progress)

    def register(self, family: str, loader: DatasetLoader[DatasetT]) -> None:
        """Register a dataset loader to use with this loader.

        :param family:
            The dataset type. If the 'dataset_family' field of an asset card
            matches this value, the specified ``loader`` will be used.
        :param loader:
            The dataset loader.

        :raises ValueError:
            If a loader is already registered for ``family``.
        """
        if family in self._loaders:
            raise ValueError(
                f"`family` must be a unique dataset family name, but '{family}' is already registered."
            )

        self._loaders[family] = loader

    def supports(self, dataset_name_or_card: str | AssetCard) -> bool:
        """Return ``True`` if the specified dataset has a registered loader.

        :param dataset_name_or_card:
            The name or the asset card of the dataset to check.
        """
        if isinstance(dataset_name_or_card, AssetCard):
            card = dataset_name_or_card
        else:
            card = self._asset_store.retrieve_card(dataset_name_or_card)

        family = card.field("dataset_family").as_(str)

        return family in self._loaders
def is_dataset_card(card: AssetCard) -> bool:
    """Return ``True`` if ``card`` specifies a dataset.

    A card is treated as a dataset card when its 'dataset_family' field exists.

    :param card:
        The asset card to inspect.
    """
    return card.field("dataset_family").exists()
def get_dataset_family(card: AssetCard) -> str:
    """Return the dataset family name contained in ``card``.

    The value is read from the 'dataset_family' field of the card as a string.

    :param card:
        The asset card to read the family name from.
    """
    # ``cast`` only narrows the type for static checkers; it does nothing at
    # runtime.
    return cast(str, card.field("dataset_family").as_(str))