# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from functools import partial
from typing import (
Any,
Callable,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
Type,
Union,
)
import torch
from ..._cpp_lib import _built_with_cuda
from ..common import BaseOperator
from .attn_bias import (
AttentionBias,
AttentionBiasSubTensor,
BlockDiagonalGappyKeysMask,
BlockDiagonalMask,
BlockDiagonalPaddedKeysMask,
LowerTriangularMask,
PagedBlockDiagonalGappyKeysMask,
PagedBlockDiagonalPaddedKeysMask,
)
def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool:
# NoneType
if isinstance(None, attn_bias_type):
return True
if attn_bias_type in [LowerTriangularMask, torch.Tensor]:
return True
return False
def _attn_bias_apply(
attn_bias: Optional[Union[torch.Tensor, AttentionBias]],
op: Callable[[torch.Tensor], torch.Tensor],
) -> Optional[Union[torch.Tensor, AttentionBias]]:
if isinstance(attn_bias, torch.Tensor) and attn_bias.ndim != 0:
return op(attn_bias)
return attn_bias
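# Illustrative sketch (hypothetical shapes): `_attn_bias_apply` applies `op` only
# to dense tensor biases; `AttentionBias` objects and 0-dim placeholder tensors
# are returned unchanged.
#
#   bias = torch.zeros(2, 8, 16, 16)  # [B, H, Mq, Mkv]
#   _attn_bias_apply(bias, partial(torch.unsqueeze, dim=1)).shape
#   # torch.Size([2, 1, 8, 16, 16])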
class ScaledTensor(torch.Tensor):
__slots__ = ["scale", "dequant_func", "original_dtype"]
# Disabling custom torch function handling for this class
__torch_function__ = torch._C._disabled_torch_function_impl
def __new__(
cls,
data: torch.Tensor,
scale: torch.Tensor,
dequant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
original_dtype: torch.dtype,
require_grad: bool = False,
) -> "ScaledTensor":
"""
Creates a new ScaledTensor subclass instance.
Parameters:
- data: The underlying quantized tensor (e.g., int8, int4).
- scale: The scale tensor or scalar to be used for dequantization.
- dequant_func: A callable that applies dequantization, which takes both the data and scale as input.
- original_dtype: The data type before quantization (e.g., float32, float16).
- require_grad: Whether or not to track gradients (default: False for inference use).
"""
# Use _make_subclass to create a new ScaledTensor instance, which is a subclass of torch.Tensor.
instance = torch.Tensor._make_subclass(cls, data, require_grad=require_grad)
# Store the dequantization scale and function as attributes.
instance.scale = scale # type: ignore
instance.dequant_func = dequant_func # type: ignore
# Store the original data type of the tensor, so we can cast it back after dequantization.
instance.original_dtype = original_dtype # type: ignore
# Return the new instance of ScaledTensor.
return instance
def dequantize(self) -> torch.Tensor:
"""
Applies the custom dequantization function provided at the tensor's creation.
After dequantization, the data is cast back to its original data type.
"""
# Explicitly create a new torch.Tensor to ensure the return type is torch.Tensor, not ScaledTensor.
data = torch.Tensor(self.float())
# Call the dequantization function, passing in the data and the scale.
dequantized_data = self.dequant_func(data, self.scale) # type: ignore
# Cast the dequantized data back to the original data type.
return dequantized_data.to(self.original_dtype) # type: ignore
def unpack(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Unpacks the ScaledTensor by returning its data and scale as a tuple.
Returns:
- A tuple of (data, scale), both of which are torch.Tensor objects.
"""
return self.data, self.scale # type: ignore
def __repr__(self):
"""
Custom string representation for ScaledTensor.
"""
return f"ScaledTensor(data={self.data}, scale={self.scale}, original_dtype={self.original_dtype})"
def pack_fp8_tensorwise_per_head(
x: torch.Tensor, scale: Union[torch.Tensor, float], original_dtype: torch.dtype
) -> ScaledTensor:
"""
Pack a quantized tensor into an fp8-style ScaledTensor whose scale broadcasts
per (batch, head), i.e. as ``scale[:, None, :, None]`` over a BMHK tensor.
"""
if isinstance(scale, float):
scale = torch.tensor([scale], device=x.device)
def dequant_func(x, scale):
return x * scale[:, None, :, None]
return ScaledTensor(
data=x,
scale=scale,
dequant_func=dequant_func,
original_dtype=original_dtype,
)
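# Illustrative sketch (hypothetical shapes and dtypes): packing a quantized BMHK
# tensor with one scale per (batch, head), then dequantizing back to float16.
#
#   data = torch.randint(-128, 128, (2, 16, 4, 64), dtype=torch.int8)
#   scale = torch.rand(2, 4)  # broadcasts as scale[:, None, :, None]
#   packed = pack_fp8_tensorwise_per_head(data, scale, torch.float16)
#   deq = packed.dequantize()  # shape (2, 16, 4, 64), dtype torch.float16
#   raw, s = packed.unpack()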
@dataclass
class Inputs:
"""
Stores inputs to the `memory_efficient_attention` operators
"""
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
p: float = 0.0
scale: Optional[float] = None
output_dtype: Optional[torch.dtype] = None
is_partial: bool = False
@property
def device(self) -> torch.device:
return self.query.device
@property
def scale_float(self) -> float:
return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale
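# Example (q/k/v assumed to be BMHK tensors with head dim K=64): with the default
# `scale=None`, the softmax scale falls back to 1/sqrt(K).
#
#   inp = Inputs(query=q, key=k, value=v)
#   inp.scale_float  # 64 ** -0.5 == 0.125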
def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.query.ndim == 5:
return self.query, self.key, self.value
if self.query.ndim == 4:
return (
self.query.unsqueeze(2),
self.key.unsqueeze(2),
self.value.unsqueeze(2),
)
if self.query.ndim == 3:
return (
self.query[:, :, None, None],
self.key[:, :, None, None],
self.value[:, :, None, None],
)
assert False
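# Illustrative sketch (hypothetical shapes): `get_qkv_in_bmghk` lifts BMK/BMHK
# inputs to the 5-d BMGHK layout by inserting singleton group/head dimensions.
#
#   q = k = v = torch.randn(2, 16, 8, 64)  # BMHK
#   Inputs(query=q, key=k, value=v).get_qkv_in_bmghk()[0].shape
#   # torch.Size([2, 16, 1, 8, 64])  (G=1)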
def normalize_bmhk(self) -> Tuple[int, ...]:
if self.query.ndim not in [3, 4, 5]:
raise ValueError(
f"Invalid shape for query: {self.query.shape}. "
"Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]"
", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
)
if self.value.dtype == torch.int32:
# Quantized K/V case, in which the last dims of Q and K are different.
# NB we currently don't have any implementations for quantized KV with
# SUPPORTS_DIFFERENT_VALUE_EMBED.
output_shape = tuple(self.query.shape)
else:
output_shape = tuple(self.query.shape[:-1]) + (self.value.shape[-1],)
# Convert from legacy format
if self.query.ndim == 3:
self.query = self.query.unsqueeze(2)
self.key = self.key.unsqueeze(2)
self.value = self.value.unsqueeze(2)
self.attn_bias = _attn_bias_apply(
self.attn_bias, partial(torch.unsqueeze, dim=1)
)
return output_shape
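# Illustrative sketch (hypothetical shapes): `normalize_bmhk` converts legacy BMK
# inputs to BMHK with a single head and returns the expected output shape.
#
#   q = k = v = torch.randn(2, 16, 64)  # BMK
#   inp = Inputs(query=q, key=k, value=v)
#   inp.normalize_bmhk()  # (2, 16, 64)
#   inp.query.shape       # torch.Size([2, 16, 1, 64])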
def validate_inputs(self) -> None:
qkv = (self.query, self.key, self.value)
if self.query.ndim not in (3, 4, 5) or any(
x.ndim != self.query.ndim for x in qkv
):
raise ValueError(
f"Query/Key/Value should all have BMGHK, BMHK or BMK shape.\n"
f" query.shape: {self.query.shape}\n"
f" key.shape : {self.key.shape}\n"
f" value.shape: {self.value.shape}"
)
if any(x.device != self.query.device for x in qkv):
raise ValueError("Query/Key/Value should all be on the same device")
if isinstance(
self.attn_bias,
(
BlockDiagonalMask,
BlockDiagonalPaddedKeysMask,
PagedBlockDiagonalPaddedKeysMask,
BlockDiagonalGappyKeysMask,
PagedBlockDiagonalGappyKeysMask,
),
):
bias_device = self.attn_bias.q_seqinfo.seqstart.device
if bias_device != self.query.device:
raise ValueError(
f"Attention bias and Query/Key/Value should be on the same device\n"
f" query.device: {self.query.device}\n"
f" attn_bias : {bias_device}\n"
)
quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32
non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv)
if not (quantized_dtypes or non_quantized_dtypes):
raise ValueError(
"Query/Key/Value should either all have the same dtype, or "
"(in the quantized case) Key/Value should have dtype torch.int32\n"
f" query.dtype: {self.query.dtype}\n"
f" key.dtype : {self.key.dtype}\n"
f" value.dtype: {self.value.dtype}"
)
# Biases with tensors attached are meant to be in BMHK format
# This would require permuting biases/gradients, which can be expensive,
# so let's just forbid it - BMK is a legacy format anyway
if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK(
type(self.attn_bias)
):
raise ValueError(
f"Please provide inputs in BMHK format rather "
f"than BMK when using bias type `{type(self.attn_bias).__name__}`"
)
attn_bias_t: Optional[torch.Tensor] = None
if isinstance(self.attn_bias, AttentionBiasSubTensor):
if self.attn_bias.HOLDS_DENSE_TENSOR:
attn_bias_t = self.attn_bias._subtensor
elif isinstance(self.attn_bias, torch.Tensor):
attn_bias_t = self.attn_bias
if self.query.ndim == 4 and attn_bias_t is not None:
expected_shape = (
self.query.shape[0],
self.query.shape[2],
self.query.shape[1],
self.key.shape[1],
)
if attn_bias_t.shape != expected_shape:
raise ValueError(
f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n"
f" query.shape: {self.query.shape}\n"
f" key.shape : {self.key.shape}\n"
f" value.shape: {self.value.shape}"
)
if isinstance(self.attn_bias, BlockDiagonalMask):
if any(x.shape[0] != 1 for x in qkv):
raise ValueError(
f"Expected batch_size=1 when using block-diagonal bias\n"
f" query.shape: {self.query.shape}\n"
f" key.shape : {self.key.shape}\n"
f" value.shape: {self.value.shape}"
)
if self.p < 0.0 or self.p > 1.0:
raise ValueError(f"Invalid dropout probability: p={self.p}")
# Check that shapes match between inputs
B, Mq = self.query.shape[:2]
K = self.query.shape[-1]
B, Mkv = self.key.shape[:2]
Kv = self.value.shape[-1]
quantized_kv_cache = self.value.dtype == torch.int32
key_embed_dim = Kv if quantized_kv_cache else K
valid_shapes = True
if self.query.ndim == 3: # BMK
valid_shapes = (
self.query.shape == (B, Mq, K)
and self.key.shape == (B, Mkv, K)
and self.value.shape == (B, Mkv, Kv)
)
H = self.query.shape[-2]
if self.query.ndim == 4: # BMHK
valid_shapes = (
self.query.shape == (B, Mq, H, K)
and self.key.shape == (B, Mkv, H, key_embed_dim)
and self.value.shape == (B, Mkv, H, Kv)
)
G = self.query.shape[2]
if self.query.ndim == 5: # BMGHK
valid_shapes = (
self.query.shape == (B, Mq, G, H, K)
and self.key.shape == (B, Mkv, G, H, key_embed_dim)
and self.value.shape == (B, Mkv, G, H, Kv)
)
if not valid_shapes:
raise ValueError(
f"Incompatible shapes for attention inputs:\n"
f" query.shape: {self.query.shape}\n"
f" key.shape : {self.key.shape}\n"
f" value.shape: {self.value.shape}\n"
"HINT: We don't support broadcasting, please use `expand` "
"yourself before calling `memory_efficient_attention` if you need to"
)
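# Illustrative sketch (hypothetical shapes): for BMHK inputs with a dense bias,
# `validate_inputs` expects the bias in [B, H, Mq, Mkv] layout.
#
#   q = torch.randn(2, 16, 8, 64)
#   k = v = torch.randn(2, 32, 8, 64)
#   bias = torch.zeros(2, 8, 16, 32)
#   Inputs(query=q, key=k, value=v, attn_bias=bias).validate_inputs()  # passes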
def get_output_dtype(self) -> torch.dtype:
if self.output_dtype is None:
if self.is_partial and self.query.dtype is not torch.float64:
return torch.float32
return self.query.dtype
return self.output_dtype
@property
def nbytes(self) -> int:
"""
Number of bytes in the input, not counting the attention bias.
"""
return sum(
x.untyped_storage().nbytes() for x in [self.query, self.key, self.value]
)
@dataclass
class Context:
lse: torch.Tensor
out: torch.Tensor
# NOTE: If `rng_state` is set, `op_bw` should be set as well
# as the randomness is backend-dependent
op_bw: Optional[Type["AttentionBwOpBase"]] = None
rng_state: Optional[Any] = None
qkv_share_storage: bool = False
def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to
lse = self.lse
if pad_amount > 0:
if force_pad_inf:
lse = lse[:, :, : self.out.shape[1]]
pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to
lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
elif force_pad_inf and self.out.shape[1] != lse.shape[2]:
lse[:, :, self.out.shape[1] :].fill_(math.inf)
return lse
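# Illustrative sketch (`ctx` assumed to be a Context with `lse` of shape
# [B, H, Mq=100]): padding the LSE seqlen dimension up to a multiple of 32.
#
#   padded = ctx.get_padded_lse(pad_to=32)
#   padded.shape  # [B, H, 128]; the 28 padded entries are filled with +inf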
@dataclass
class Gradients:
dq: torch.Tensor
dk: torch.Tensor
dv: torch.Tensor
# bias gradient. None if there is no tensor bias or if it doesn't require grad
db: Optional[torch.Tensor] = None
class AttentionOpBase(BaseOperator):
"""Base class for any attention operator in xFormers
See:
- :attr:`xformers.ops.fmha.cutlass.FwOp`
- :attr:`xformers.ops.fmha.cutlass.BwOp`
- :attr:`xformers.ops.fmha.flash.FwOp`
- :attr:`xformers.ops.fmha.flash.BwOp`
- :attr:`xformers.ops.fmha.triton.FwOp`
- :attr:`xformers.ops.fmha.triton.BwOp`
"""
OPERATOR: Any
SUPPORTED_DEVICES: Set[str]
CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
SUPPORTED_DTYPES: Set[torch.dtype]
SUPPORTED_MAX_K: float
SUPPORTED_MIN_K: int = 0
SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (type(None),)
SUPPORTS_DROPOUT: bool
SUPPORTS_CUSTOM_SCALE: bool = False
SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
SUPPORTS_OUTPUT_DTYPE: bool = False
SUPPORTS_PARTIAL: bool = False
IS_DETERMINISTIC: bool = True
SUPPORTS_BMGHK: bool = False
NAME: str
OPERATOR_CATEGORY = "memory_efficient_attention"
# Format for the LSE computed in the FW pass, and accepted in the BW pass,
# for BlockDiagonalMask and children.
# When using a varlen bias, both the FW and BW operators must have the
# same value for `VARLEN_LSE_PACKED`
VARLEN_LSE_PACKED: bool = True
_TEST_BATCH_SIZES: List[int] = [1, 300]
_TEST_K: List[int] = [32, 128]
@classmethod
def supports(cls, d: Inputs) -> bool:
return not cls.not_supported_reasons(d)
@classmethod
def shape_not_supported_reasons(
cls, Mq: int, Mkv: int, K: int, Kv: int
) -> List[str]:
reasons = []
if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
reasons.append("query.shape[-1] != value.shape[-1]")
if max(K, Kv) > cls.SUPPORTED_MAX_K:
reasons.append(
f"max(query.shape[-1], value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
)
if min(K, Kv) < cls.SUPPORTED_MIN_K:
reasons.append(
f"min(query.shape[-1], value.shape[-1]) < {cls.SUPPORTED_MIN_K}"
)
return reasons
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
"""
Returns a list of reasons why this operator does not support these inputs.
The operator can run on these inputs only if the returned list is empty.
"""
query_shape = d.query.shape
reasons = cls.shape_not_supported_reasons(
Mq=query_shape[1],
Mkv=d.key.shape[1],
K=query_shape[-1],
Kv=query_shape[-1] if d.value.dtype == torch.int32 else d.value.shape[-1],
)
device_type = d.query.device.type
dtype = d.query.dtype
if device_type not in cls.SUPPORTED_DEVICES:
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
if (
device_type == "cuda"
and not _built_with_cuda
and (torch.version.hip is None)
):
reasons.append("xFormers wasn't build with CUDA support")
if device_type == "cuda" and (torch.version.hip is None):
device_capability = torch.cuda.get_device_capability(d.device)
if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
reasons.append(
f"requires device with capability > {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
f"but your GPU has capability {device_capability} (too old)"
)
if dtype not in cls.SUPPORTED_DTYPES:
reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
reasons.append(f"attn_bias type is {type(d.attn_bias)}")
if not cls.SUPPORTS_OUTPUT_DTYPE:
if d.output_dtype is not None and d.output_dtype is not dtype:
reasons.append("Custom output dtype not supported")
if d.is_partial and not cls.SUPPORTS_PARTIAL:
reasons.append("Partial attention not supported")
if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
reasons.append("dropout > 0.0")
if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
reasons.append("has custom scale")
# bfloat16 is only supported on A100+
# ... although the kernels can still run and give the
# correct result
if dtype is torch.bfloat16 and (
not device_type.startswith("cuda")
or torch.cuda.get_device_capability(d.query.device)[0] < 8
):
reasons.append("bf16 is only supported on A100+ GPUs")
if not cls.is_available():
reasons.append(
"operator wasn't built - see `python -m xformers.info` for more info"
)
if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
reasons.append(
"operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
)
if not cls.SUPPORTS_BMGHK and d.query.ndim == 5:
reasons.append("operator does not support BMGHK format")
return reasons
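# Illustrative sketch (hypothetical operator class `SomeFwOp` and pre-built
# tensors q/k/v): callers typically consult `supports` / `not_supported_reasons`
# before dispatching to an operator.
#
#   inp = Inputs(query=q, key=k, value=v)
#   if not SomeFwOp.supports(inp):
#       print(SomeFwOp.not_supported_reasons(inp))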
class AttentionFwOpBase(AttentionOpBase):
ERROR_ATOL: Mapping[torch.dtype, float] = {
torch.float: 3e-4,
torch.half: 4e-3,
torch.bfloat16: 2e-2,
}
ERROR_RTOL: Mapping[torch.dtype, float] = {
torch.float: 2e-5,
torch.half: 4e-4,
torch.bfloat16: 5e-3,
}
@classmethod
def apply(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
raise NotImplementedError()
class AttentionBwOpBase(AttentionOpBase):
# NOTE on tolerances: These are tested for `scales <= (1/32) ** 0.5`
# In the BW pass, imprecisions accumulate in the Q@K.T recalculation
# These imprecisions are multiplied by the `scale` and then exponentiated
# So if the scale is too high, we get a lot of errors
ERROR_ATOL: Mapping[torch.dtype, float] = {
torch.float: 9e-4,
torch.half: 0.2,
torch.bfloat16: 0.9,
}
ERROR_RTOL: Mapping[torch.dtype, float] = {
torch.float: 1e-4,
torch.half: 2e-2,
torch.bfloat16: 0.1,
}
SUPPORTS_ATTN_BIAS_GRAD = False
SUPPORTS_PARTIAL = True
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d)
if (
isinstance(d.attn_bias, torch.Tensor)
and d.attn_bias.requires_grad
and not cls.SUPPORTS_ATTN_BIAS_GRAD
):
reasons.append(
"Computing the bias gradient is not supported (attn_bias.requires_grad = True)"
)
return reasons
@classmethod
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
raise NotImplementedError()
AttentionOp = Tuple[
Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
]
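# Illustrative sketch: an `AttentionOp` is a (forward, backward) pair of operator
# classes; either entry may be None. Using the cutlass ops referenced in the
# `AttentionOpBase` docstring above (import path assumed):
#
#   from xformers.ops.fmha import cutlass
#   op: AttentionOp = (cutlass.FwOp, cutlass.BwOp)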
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
if tensor.ndim == 4:
return tensor
return tensor.reshape(
[tensor.shape[0] // num_heads, num_heads, tensor.shape[1], tensor.shape[2]]
).permute((0, 2, 1, 3))
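# Illustrative sketch (hypothetical shapes): `bmk2bmhk` unfolds the heads from
# the batch dimension, turning a [B * H, M, K] tensor into [B, M, H, K].
#
#   t = torch.randn(6, 16, 64)         # B=2, H=3 folded into the first dim
#   bmk2bmhk(t, num_heads=3).shape     # torch.Size([2, 16, 3, 64])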
def check_lastdim_alignment_stride1(
reasons: List[str], name: str, x: torch.Tensor, alignment: int
) -> None:
if x.shape[-1] % alignment != 0:
reasons.append(f"{name}.shape[-1] % {alignment} != 0")
elif x.stride(-2) % alignment != 0:
reasons.append(
f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
)
# We can have stride=0 sometimes if dimension=1
if x.stride(-1) > 1:
reasons.append(
f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
)
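# Illustrative sketch (the 8-element alignment below is an assumed requirement):
# collecting alignment problems for a tensor whose last dim is not contiguous.
#
#   reasons: List[str] = []
#   x = torch.randn(2, 16, 64)[..., ::2]  # stride(-1) == 2
#   check_lastdim_alignment_stride1(reasons, "query", x, 8)
#   # reasons now flags `query.stride(-1) > 1` - call `.contiguous()` first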