fairseq2.data_type
This module provides functions for managing PyTorch data types.
ABCs
- class fairseq2.data_type.DataTypeContext
Bases: ABC
Provides methods to get and set the current floating-point data type of the calling thread.
This interface can be used as an alternative to the corresponding standalone functions in object-oriented code.
- abstract get_current_dtype() → dtype
See get_current_dtype().
- abstract set_dtype(dtype: dtype) → AbstractContextManager[None]
See set_dtype().
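Because the interface mirrors the standalone functions, a component can receive its data-type handling as a dependency instead of calling the module-level functions directly, which also makes it easy to swap in a test double. The following is a minimal sketch of that pattern, not fairseq2's own code. So that it runs with PyTorch alone, it defines a local stand-in ABC, `DtypeContextLike`, with the same two method signatures as `DataTypeContext`; `DtypeContextLike`, `FakeDtypeContext`, and `HypotheticalWeightFactory` are all hypothetical names.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, contextmanager
from typing import Iterator

import torch


class DtypeContextLike(ABC):
    """Hypothetical stand-in with the same two methods as DataTypeContext."""

    @abstractmethod
    def get_current_dtype(self) -> torch.dtype: ...

    @abstractmethod
    def set_dtype(self, dtype: torch.dtype) -> AbstractContextManager[None]: ...


class FakeDtypeContext(DtypeContextLike):
    """Hypothetical test double: keeps a dtype stack in memory and never touches PyTorch globals."""

    def __init__(self, initial: torch.dtype = torch.float32) -> None:
        self._stack: list[torch.dtype] = [initial]
        self.requested: list[torch.dtype] = []  # every dtype passed to set_dtype()

    def get_current_dtype(self) -> torch.dtype:
        return self._stack[-1]

    @contextmanager
    def set_dtype(self, dtype: torch.dtype) -> Iterator[None]:
        self.requested.append(dtype)
        self._stack.append(dtype)
        try:
            yield
        finally:
            # Restore the previous dtype when the scope exits, even on error.
            self._stack.pop()


class HypotheticalWeightFactory:
    """Depends on the interface rather than on module-level functions."""

    def __init__(self, dtype_context: DtypeContextLike) -> None:
        self._dtype_context = dtype_context

    def zeros(self, size: int, dtype: torch.dtype) -> torch.Tensor:
        with self._dtype_context.set_dtype(dtype):
            # Pass the dtype explicitly: the fake context does not change PyTorch's default.
            return torch.zeros(size, dtype=self._dtype_context.get_current_dtype())


ctx = FakeDtypeContext()
weights = HypotheticalWeightFactory(ctx).zeros(3, torch.float64)
```

In production code the factory would be handed a real `DataTypeContext` implementation; in tests, the fake lets you assert which dtypes were requested without modifying any process-wide state.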
Functions
- fairseq2.data_type.get_current_dtype() → dtype
Returns the current floating-point data type of the calling thread.
Warning
This function might impose a slight performance cost. Avoid calling it in hot code paths.
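Following the warning above, one way to keep the lookup out of a hot loop is to query the dtype once and pass it explicitly to each tensor constructor. In this minimal sketch, `current_dtype_standin()` is a hypothetical placeholder for `get_current_dtype()` that simply returns `torch.get_default_dtype()`, so the example runs with PyTorch alone; `make_buffers()` is likewise hypothetical.

```python
from __future__ import annotations

import torch


def current_dtype_standin() -> torch.dtype:
    # Hypothetical placeholder for fairseq2.data_type.get_current_dtype();
    # it only reports PyTorch's default dtype.
    return torch.get_default_dtype()


def make_buffers(shapes: list[tuple[int, ...]]) -> list[torch.Tensor]:
    # Look the dtype up once, outside the loop, instead of once per tensor.
    dtype = current_dtype_standin()
    return [torch.zeros(shape, dtype=dtype) for shape in shapes]


buffers = make_buffers([(2, 3), (4,)])
```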
- fairseq2.data_type.set_dtype(dtype: dtype) → AbstractContextManager[None]
Changes the floating-point data type of the calling thread to the specified type.
This function acts as a context manager, ensuring that within its scope any operation that constructs tensors uses the specified data type, unless an explicit dtype argument is provided.

```python
import torch

from fairseq2.data_type import set_dtype

with set_dtype(torch.bfloat16):
    t = torch.ones((4, 4))
    assert t.dtype == torch.bfloat16

    with set_dtype(torch.float16):
        t = torch.ones((4, 4))
        assert t.dtype == torch.float16

t = torch.ones((4, 4))
assert t.dtype == torch.float32
```
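To make the scoping behavior concrete, here is a minimal sketch of a context manager with similar semantics, built on PyTorch's public `torch.get_default_dtype()` and `torch.set_default_dtype()`. It is not fairseq2's implementation: the name `scoped_default_dtype` is hypothetical, and the sketch makes no attempt to confine the change to the calling thread. It records the previous default, applies the new one, and restores the old value in a `finally` clause so that nested scopes and exceptions both unwind correctly.

```python
from __future__ import annotations

from contextlib import contextmanager
from typing import Iterator

import torch


@contextmanager
def scoped_default_dtype(dtype: torch.dtype) -> Iterator[None]:
    """Hypothetical sketch: temporarily change PyTorch's default floating-point dtype.

    Unlike the documented set_dtype(), this does not scope the change to the
    calling thread; it simply calls torch.set_default_dtype().
    """
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default on normal exit and on exceptions alike.
        torch.set_default_dtype(previous)


outer_default = torch.get_default_dtype()

with scoped_default_dtype(torch.float64):
    assert torch.ones((4, 4)).dtype == torch.float64
    # An explicit dtype argument still takes precedence over the scoped default.
    assert torch.ones((4, 4), dtype=torch.float32).dtype == torch.float32

    with scoped_default_dtype(torch.float32):
        assert torch.ones((4, 4)).dtype == torch.float32

    # Leaving the inner scope restores the outer one.
    assert torch.ones((4, 4)).dtype == torch.float64

# Leaving the outer scope restores whatever default was active before it.
assert torch.get_default_dtype() == outer_default
```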