fairseq2.data_type

This module provides functions for managing PyTorch data types.

ABCs

class fairseq2.data_type.DataTypeContext

Bases: ABC

Provides methods to get and set the current floating-point data type of the calling thread.

This interface can be used as an alternative to the corresponding standalone functions in object-oriented code.

abstract get_current_dtype() → dtype

See get_current_dtype().

abstract set_dtype(dtype: dtype) → AbstractContextManager[None]

See set_dtype().
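As a rough illustration of the object-oriented use mentioned above, the sketch below injects a context object into a component instead of calling the standalone functions. It is standard-library only so that it runs without fairseq2 or PyTorch: `DataTypeContextLike` is a hypothetical stand-in that mirrors the two documented methods, `FakeDataTypeContext` and `ModelFactory` are invented for this example, and data types are plain strings rather than `torch.dtype` values.

```python
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Iterator, List


class DataTypeContextLike(ABC):
    """Hypothetical stand-in mirroring the two documented methods."""

    @abstractmethod
    def get_current_dtype(self) -> str: ...

    @abstractmethod
    def set_dtype(self, dtype: str) -> "AbstractContextManager[None]": ...


class FakeDataTypeContext(DataTypeContextLike):
    """In-memory fake, e.g. for unit tests; dtypes are strings here."""

    def __init__(self, default: str = "float32") -> None:
        self._stack: List[str] = [default]

    def get_current_dtype(self) -> str:
        return self._stack[-1]

    @contextmanager
    def set_dtype(self, dtype: str) -> Iterator[None]:
        # Push on entry and pop on exit, so nested scopes unwind in order.
        self._stack.append(dtype)
        try:
            yield
        finally:
            self._stack.pop()


class ModelFactory:
    """Hypothetical component that receives its context by injection."""

    def __init__(self, dtype_context: DataTypeContextLike) -> None:
        self._dtype_context = dtype_context

    def describe(self) -> str:
        return f"parameters created as {self._dtype_context.get_current_dtype()}"


context = FakeDataTypeContext()
factory = ModelFactory(context)

with context.set_dtype("bfloat16"):
    inside = factory.describe()
outside = factory.describe()
```

Because `ModelFactory` depends only on the interface, a test can pass in the fake, while production code would pass in whatever `DataTypeContext` implementation the application provides.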

Functions

fairseq2.data_type.get_current_dtype() → dtype

Returns the current floating-point data type of the calling thread.

Warning

This function might impose a slight performance cost. Avoid calling it in hot code paths.
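One way to follow this advice is to read the data type once, outside the loop, and reuse the result. The sketch below is library-agnostic so that it runs on its own: `make_buffers`, `counting_getter`, and the string data type are hypothetical, and in real code the callable passed in would presumably be `fairseq2.data_type.get_current_dtype`.

```python
from typing import Callable, List, Tuple


def make_buffers(count: int, get_dtype: Callable[[], str]) -> List[Tuple[int, str]]:
    # Read the data type once, before the loop, instead of once per item.
    dtype = get_dtype()
    return [(index, dtype) for index in range(count)]


calls: List[str] = []


def counting_getter() -> str:
    # Hypothetical getter that records every invocation.
    calls.append("called")
    return "bfloat16"


buffers = make_buffers(1000, counting_getter)
```

Note that hoisting the lookup only works if the data type cannot change while the loop runs, for example because no `set_dtype()` scope is entered inside it.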

fairseq2.data_type.set_dtype(dtype: dtype) → AbstractContextManager[None]

Changes the floating-point data type of the calling thread to the specified type.

This function acts as a context manager: within its scope, any operation that constructs tensors uses the specified data type unless an explicit dtype argument is provided. Scopes can be nested, and the previous data type applies again once a scope exits.

import torch

from fairseq2.data_type import set_dtype

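# Tensors constructed inside the scope default to bfloat16.
with set_dtype(torch.bfloat16):
    t = torch.ones((4, 4))

    assert t.dtype == torch.bfloat16

    # Scopes nest: the innermost one takes precedence.
    with set_dtype(torch.float16):
        t = torch.ones((4, 4))

        assert t.dtype == torch.float16

# Outside all scopes, the previous data type applies again
# (float32 here, assuming the PyTorch default was not changed).
t = torch.ones((4, 4))

assert t.dtype == torch.float32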
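The scoping behaviour in the example above can be modelled with a thread-local value and a context manager that restores the previous value on exit. This is a conceptual sketch only, not fairseq2's implementation: `current_dtype` and `scoped_dtype` are hypothetical names, data types are plain strings, and it assumes that "of the calling thread" means other threads are unaffected by a scope.

```python
import threading
from contextlib import contextmanager
from typing import Iterator, List

_state = threading.local()
_DEFAULT = "float32"


def current_dtype() -> str:
    # Threads that never entered a scope fall back to the default.
    return getattr(_state, "dtype", _DEFAULT)


@contextmanager
def scoped_dtype(dtype: str) -> Iterator[None]:
    previous = current_dtype()
    _state.dtype = dtype
    try:
        yield
    finally:
        # Restore the enclosing scope's value, even if an error occurred.
        _state.dtype = previous


seen_in_worker: List[str] = []


def worker() -> None:
    seen_in_worker.append(current_dtype())


with scoped_dtype("bfloat16"):
    with scoped_dtype("float16"):
        innermost = current_dtype()
    after_inner = current_dtype()

    # A separate thread does not observe this thread's scope.
    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()

after_outer = current_dtype()
```

Saving and restoring the previous value in `finally` is what makes nesting safe: each scope only needs to know the value it replaced.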