Source code for fairseq2.data_type

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This module provides functions for managing PyTorch data types.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterator, Set
from contextlib import contextmanager
from typing import Any, TypeAlias, final

import torch
from torch import get_default_dtype
from torch.overrides import TorchFunctionMode
from typing_extensions import override

from fairseq2.runtime.dependency import get_dependency_resolver
from fairseq2.typing import ContextManager
from fairseq2.utils.threading import ThreadLocalStorage
from fairseq2.utils.warn import _warn_deprecated

DataType: TypeAlias = torch.dtype


def set_dtype(dtype: DataType) -> ContextManager[None]:
    """
    Changes the floating-point data type of the calling thread to the specified
    type.

    This function acts as a context manager, ensuring that within its scope,
    any operation that constructs tensors uses the specified data type, unless
    an explicit ``dtype`` argument is provided.

    .. code:: python

        import torch

        from fairseq2.data_type import set_dtype

        with set_dtype(torch.bfloat16):
            t = torch.ones((4, 4))

            assert t.dtype == torch.bfloat16

            with set_dtype(torch.float16):
                t = torch.ones((4, 4))

                assert t.dtype == torch.float16

        t = torch.ones((4, 4))

        assert t.dtype == torch.float32
    """
    resolver = get_dependency_resolver()

    return resolver.resolve(DataTypeContext).set_dtype(dtype)

def default_dtype(dtype: DataType) -> ContextManager[None]:
    _warn_deprecated(
        "`default_dtype()` is deprecated and will be removed in v0.14. Use `set_dtype()` instead."
    )

    return set_dtype(dtype)

def get_current_dtype() -> DataType:
    """
    Returns the current floating-point data type of the calling thread.

    .. warning::

        This function might impose a slight performance cost. Avoid calling it
        in hot code paths.
    """
    resolver = get_dependency_resolver()

    return resolver.resolve(DataTypeContext).get_current_dtype()

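
# Illustrative usage sketch: how `get_current_dtype()` is expected to reflect an
# enclosing `set_dtype()` scope. The helper name below is hypothetical and exists
# only for demonstration; the float32 checks assume the process-wide default
# dtype has not been changed.
def _example_get_current_dtype() -> None:
    assert get_current_dtype() == torch.float32

    with set_dtype(torch.bfloat16):
        # Inside the scope, the thread-local dtype is reported.
        assert get_current_dtype() == torch.bfloat16

    # Outside the scope, the default dtype is reported again.
    assert get_current_dtype() == torch.float32
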
class DataTypeContext(ABC):
    """
    Provides methods to get and set the current floating-point data type of the
    calling thread.

    This interface can be used as an alternative to the corresponding
    standalone functions in object-oriented code.
    """

    @abstractmethod
    def get_current_dtype(self) -> DataType:
        """See :func:`get_current_dtype`."""

    @abstractmethod
    def set_dtype(self, dtype: DataType) -> ContextManager[None]:
        """See :func:`set_dtype`."""

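
# Illustrative usage sketch: consuming `DataTypeContext` through the dependency
# resolver as the object-oriented alternative to the standalone functions above.
# The class below is hypothetical and exists only for demonstration.
class _ExampleTensorFactory:
    def __init__(self, dtype_context: DataTypeContext) -> None:
        self._dtype_context = dtype_context

    def make_ones(self, shape: tuple[int, ...], dtype: DataType) -> torch.Tensor:
        # Constructors called inside the scope pick up `dtype` automatically.
        with self._dtype_context.set_dtype(dtype):
            return torch.ones(shape)


# Hypothetical usage:
#
#   factory = _ExampleTensorFactory(get_dependency_resolver().resolve(DataTypeContext))
#
#   assert factory.make_ones((2, 2), torch.float16).dtype == torch.float16
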

@final
class _StandardDataTypeContext(DataTypeContext):
    def __init__(self, mode_stack: _DataTypeModeStack) -> None:
        self._mode_stack = mode_stack

    @override
    def get_current_dtype(self) -> DataType:
        mode = self._mode_stack.maybe_get_top_mode()
        if mode is not None:
            return mode.dtype

        # No `set_dtype()` scope is active; report the process-wide default.
        return get_default_dtype()

    @override
    @contextmanager
    def set_dtype(self, dtype: DataType) -> Iterator[None]:
        mode = self._mode_stack.set_mode(dtype)

        try:
            with mode:
                yield
        finally:
            self._mode_stack.pop_mode()


@final
class _DataTypeModeStack:
    def __init__(self, constructors: Set[Any], tls: ThreadLocalStorage) -> None:
        self._constructors = constructors
        self._tls = tls

    def set_mode(self, dtype: DataType) -> _DataTypeMode:
        mode = _DataTypeMode(dtype, self._constructors)

        modes = self._get_dtype_mode_stack()

        modes.append(mode)

        # Only the innermost mode should rewrite constructor calls.
        if len(modes) > 1:
            modes[-2].enabled = False

        return mode

    def pop_mode(self) -> None:
        modes = self._get_dtype_mode_stack()

        # Re-activate the enclosing mode, if any.
        if len(modes) > 1:
            modes[-2].enabled = True

        modes.pop()

    def maybe_get_top_mode(self) -> _DataTypeMode | None:
        modes = self._get_dtype_mode_stack()
        if modes:
            return modes[-1]

        return None

    def _get_dtype_mode_stack(self) -> list[_DataTypeMode]:
        # The stack is thread-local so that `set_dtype()` scopes on different
        # threads do not interfere with each other.
        return self._tls.get("dtype_mode_stack", list)


@final
class _DataTypeMode(TorchFunctionMode):
    def __init__(self, dtype: DataType, constructors: Set[Any]) -> None:
        self._constructors = constructors

        self.dtype = dtype
        self.enabled = True

    def __torch_function__(  # type: ignore[override]
        self, func: Any, types: Any, args: Any, kwargs: Any = None
    ) -> Any:
        if kwargs is None:
            kwargs = {}

        # Inject the scoped dtype into tensor constructor calls that do not
        # specify an explicit `dtype` argument.
        if self.enabled and func in self._constructors:
            dtype = kwargs.get("dtype", None)
            if dtype is None:
                kwargs["dtype"] = self.dtype

        return func(*args, **kwargs)


def _tensor_constructors() -> set[Any]:
    # Taken from torch/utils/_device.py.
    return {
        torch.empty,
        torch.empty_permuted,
        torch.empty_strided,
        torch.empty_quantized,
        torch.ones,
        torch.arange,
        torch.bartlett_window,
        torch.blackman_window,
        torch.eye,
        torch.fft.fftfreq,
        torch.fft.rfftfreq,
        torch.full,
        torch.hamming_window,
        torch.hann_window,
        torch.kaiser_window,
        torch.linspace,
        torch.logspace,
        torch.nested.nested_tensor,
        torch.rand,
        torch.randn,
        torch.randint,
        torch.randperm,
        torch.range,
        torch.sparse_coo_tensor,
        torch.sparse_compressed_tensor,
        torch.sparse_csr_tensor,
        torch.sparse_csc_tensor,
        torch.sparse_bsr_tensor,
        torch.sparse_bsc_tensor,
        torch.tril_indices,
        torch.triu_indices,
        torch.zeros,
        torch.asarray,
        torch.tensor,
        torch.as_tensor,
        torch.scalar_tensor,
    }
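
# Illustrative usage sketch: the intended effect of the `TorchFunctionMode`-based
# interception above. Constructors listed in `_tensor_constructors()` pick up the
# scoped dtype unless an explicit ``dtype`` argument is passed. The function name
# below is hypothetical and exists only for demonstration.
def _example_constructor_interception() -> None:
    with set_dtype(torch.float16):
        a = torch.zeros(2, 3)  # no explicit dtype, uses the scoped float16
        b = torch.zeros(2, 3, dtype=torch.float64)  # explicit dtype wins

    assert a.dtype == torch.float16
    assert b.dtype == torch.float64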