# Source code for xformers.ops.fmha.common

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from functools import partial
from typing import (
    Any,
    Callable,
    Iterable,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import torch

from ..._cpp_lib import _built_with_cuda
from ..common import BaseOperator
from .attn_bias import (
    AttentionBias,
    BlockDiagonalMask,
    LowerTriangularMask,
    LowerTriangularMaskWithTensorBias,
)


def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool:
    """Tell whether a bias of type ``attn_bias_type`` is accepted with BMK inputs.

    Args:
        attn_bias_type: The *type* of the attention bias (not an instance).

    Returns:
        ``True`` when ``None`` is an instance of ``attn_bias_type`` (in
        practice: ``type(None)``, i.e. no bias was given), or when the type is
        ``LowerTriangularMask`` or ``torch.Tensor``; ``False`` otherwise.
    """
    # `isinstance(None, T)` is how the "NoneType" case is detected.
    accepts_none = isinstance(None, attn_bias_type)
    return accepts_none or attn_bias_type in (LowerTriangularMask, torch.Tensor)


def _attn_bias_apply(
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]],
    op: Callable[[torch.Tensor], torch.Tensor],
) -> Optional[Union[torch.Tensor, AttentionBias]]:
    """Apply ``op`` to the tensor carried by ``attn_bias``, if it carries one.

    Args:
        attn_bias: ``None``, a plain tensor, or an ``AttentionBias`` object.
        op: Tensor-to-tensor function to apply.

    Returns:
        * ``op(attn_bias)`` when ``attn_bias`` is a plain tensor;
        * a new ``LowerTriangularMaskWithTensorBias`` wrapping
          ``op(attn_bias._bias)`` when it is such a mask;
        * ``attn_bias`` itself, untouched, in every other case
          (including ``None``).
    """
    if isinstance(attn_bias, torch.Tensor):
        result = op(attn_bias)
    elif isinstance(attn_bias, LowerTriangularMaskWithTensorBias):
        transformed_bias = op(attn_bias._bias)
        result = LowerTriangularMaskWithTensorBias(transformed_bias)
    else:
        # Nothing tensor-like to transform: hand the input back as-is.
        result = attn_bias
    return result


@dataclass
class Inputs:
    """
    Stores inputs to the `memory_efficient_attention` operators

    ``query``, ``key`` and ``value`` are accepted with 5 dims (BMGHK),
    4 dims (BMHK) or 3 dims (legacy BMK); see :meth:`validate_inputs`.
    """

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    # Optional attention bias: a raw tensor or an `AttentionBias` object.
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
    # Dropout probability; `validate_inputs` requires 0.0 <= p <= 1.0.
    p: float = 0.0
    # Custom scale; when None, `scale_float` falls back to K ** -0.5.
    scale: Optional[float] = None
    # Requested output dtype; when None, `get_output_dtype` picks one.
    output_dtype: Optional[torch.dtype] = None
    # Affects the default output dtype, see `get_output_dtype`.
    is_partial: bool = False

    @property
    def device(self) -> torch.device:
        """Device of ``query``."""
        return self.query.device

    @property
    def scale_float(self) -> float:
        """The explicit ``scale`` if set, else ``query.shape[-1] ** -0.5``."""
        return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale

    def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(query, key, value)`` expanded to 5 dimensions (BMGHK).

        5-D inputs are returned as-is, 4-D inputs get a singleton axis
        inserted at position 2, and 3-D inputs get two singleton axes inserted
        at positions 2 and 3. The stored tensors are not modified.

        Raises:
            AssertionError: if ``query`` is not 3-, 4- or 5-dimensional.
        """
        # Dispatch on the query rank in every branch so that the three cases
        # are mutually consistent.
        ndim = self.query.ndim
        if ndim == 5:
            return self.query, self.key, self.value
        if ndim == 4:
            return (
                self.query.unsqueeze(2),
                self.key.unsqueeze(2),
                self.value.unsqueeze(2),
            )
        if ndim == 3:
            return (
                self.query[:, :, None, None],
                self.key[:, :, None, None],
                self.value[:, :, None, None],
            )
        # Raised explicitly (rather than `assert False`) so that the check
        # survives `python -O` instead of silently returning None.
        raise AssertionError(
            f"Unsupported number of dimensions for query: {ndim} (expected 3, 4 or 5)"
        )

    def normalize_bmhk(self) -> Tuple[int, ...]:
        """Convert legacy BMK inputs to BMHK in place and return the output shape.

        For 3-D inputs, ``query``/``key``/``value`` stored on ``self`` are
        replaced by views with a singleton axis at position 2, and a tensor
        bias (if any) gets a singleton axis at position 1. 4-D and 5-D inputs
        are left untouched.

        Returns:
            The shape of ``query`` *before* conversion, with its last
            dimension replaced by ``value.shape[-1]`` -- except when ``value``
            is ``torch.int32``, where the query shape is returned unchanged.

        Raises:
            ValueError: if ``query`` is not 3-, 4- or 5-dimensional.
        """
        if self.query.ndim not in [3, 4, 5]:
            raise ValueError(
                f"Invalid shape for query: {self.query.shape}. "
                "Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]"
                ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
            )
        if self.value.dtype == torch.int32:
            # Quantized K/V case, in which the last dims of Q and K are different.
            # NB we currently don't have any implementations for quantized KV with
            # SUPPORTS_DIFFERENT_VALUE_EMBED.
            output_shape = tuple(self.query.shape)
        else:
            output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],)
        # Convert from legacy format
        if self.query.ndim == 3:
            self.query = self.query.unsqueeze(2)
            self.key = self.key.unsqueeze(2)
            self.value = self.value.unsqueeze(2)
            self.attn_bias = _attn_bias_apply(
                self.attn_bias, partial(torch.unsqueeze, dim=1)
            )
        return output_shape

    def validate_inputs(self) -> None:
        """Check that the stored inputs are mutually consistent.

        Verified, in order: matching rank (3, 4 or 5) of query/key/value, same
        device, dtype compatibility (all equal, or key/value both int32), bias
        type allowed for BMK inputs, tensor-bias shape for BMHK inputs,
        batch size 1 with ``BlockDiagonalMask``, dropout probability range,
        and finally that the individual dimensions of query/key/value agree.

        Raises:
            ValueError: on the first check that fails.
        """
        qkv = (self.query, self.key, self.value)
        if self.query.ndim not in (3, 4, 5) or any(
            x.ndim != self.query.ndim for x in qkv
        ):
            raise ValueError(
                f"Query/Key/Value should all have BMGHK, BMHK or BMK shape.\n"
                f"  query.shape: {self.query.shape}\n"
                f"  key.shape  : {self.key.shape}\n"
                f"  value.shape: {self.value.shape}"
            )
        if any(x.device != self.query.device for x in qkv):
            raise ValueError("Query/Key/Value should all be on the same device")
        quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32
        non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv)
        if not (quantized_dtypes or non_quantized_dtypes):
            raise ValueError(
                "Query/Key/Value should either all have the same dtype, or "
                "(in the quantized case) Key/Value should have dtype torch.int32\n"
                f"  query.dtype: {self.query.dtype}\n"
                f"  key.dtype  : {self.key.dtype}\n"
                f"  value.dtype: {self.value.dtype}"
            )
        # Biases with tensors attached are meant to be in BMHK format
        # This would require to permute biases/gradients which can be expensive,
        # so let's just forbid it - BMK is a legacy format anyway
        if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK(
            type(self.attn_bias)
        ):
            raise ValueError(
                f"Please provide inputs in BMHK format rather "
                f"than BMK when using bias type `{type(self.attn_bias).__name__}`"
            )
        # Extract the raw bias tensor, if the bias carries one.
        attn_bias_t: Optional[torch.Tensor] = None
        if isinstance(self.attn_bias, torch.Tensor):
            attn_bias_t = self.attn_bias
        if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias):
            attn_bias_t = self.attn_bias._bias
        if self.query.ndim == 4 and attn_bias_t is not None:
            # For BMHK inputs the bias must be laid out as (B, H, Mq, Mkv).
            expected_shape = (
                self.query.shape[0],
                self.query.shape[2],
                self.query.shape[1],
                self.key.shape[1],
            )
            if attn_bias_t.shape != expected_shape:
                raise ValueError(
                    f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n"
                    f"  query.shape: {self.query.shape}\n"
                    f"  key.shape  : {self.key.shape}\n"
                    f"  value.shape: {self.value.shape}"
                )
        if isinstance(self.attn_bias, BlockDiagonalMask):
            if any(x.shape[0] != 1 for x in qkv):
                raise ValueError(
                    f"Expected batch_size=1 when using block-diagonal bias\n"
                    f"  query.shape: {self.query.shape}\n"
                    f"  key.shape  : {self.key.shape}\n"
                    f"  value.shape: {self.value.shape}"
                )
        if self.p < 0.0 or self.p > 1.0:
            raise ValueError(f"Invalid dropout probability: p={self.p}")
        # Check that shapes match between inputs
        Mq = self.query.shape[1]
        K = self.query.shape[-1]
        # The reference batch size is the one of `key`: comparing the query and
        # value shapes against it below is what enforces a common batch size.
        B, Mkv = self.key.shape[:2]
        Kv = self.value.shape[-1]
        quantized_kv_cache = self.value.dtype == torch.int32
        key_embed_dim = Kv if quantized_kv_cache else K

        valid_shapes = True
        if self.query.ndim == 3:  # BMK
            valid_shapes = (
                self.query.shape == (B, Mq, K)
                and self.key.shape == (B, Mkv, K)
                and self.value.shape == (B, Mkv, Kv)
            )
        elif self.query.ndim == 4:  # BMHK
            H = self.query.shape[-2]
            valid_shapes = (
                self.query.shape == (B, Mq, H, K)
                and self.key.shape == (B, Mkv, H, key_embed_dim)
                and self.value.shape == (B, Mkv, H, Kv)
            )
        elif self.query.ndim == 5:  # BMGHK
            G = self.query.shape[2]
            H = self.query.shape[-2]
            valid_shapes = (
                self.query.shape == (B, Mq, G, H, K)
                and self.key.shape == (B, Mkv, G, H, key_embed_dim)
                and self.value.shape == (B, Mkv, G, H, Kv)
            )
        if not valid_shapes:
            raise ValueError(
                f"Incompatible shapes for attention inputs:\n"
                f"  query.shape: {self.query.shape}\n"
                f"  key.shape  : {self.key.shape}\n"
                f"  value.shape: {self.value.shape}\n"
                "HINT: We don't support broadcasting, please use `expand` "
                "yourself before calling `memory_efficient_attention` if you need to"
            )

    def get_output_dtype(self) -> torch.dtype:
        """Return the dtype the output should have.

        An explicit ``output_dtype`` always wins. Otherwise the query dtype is
        used, except for partial attention (``is_partial``) where float32 is
        returned unless the query is float64.
        """
        if self.output_dtype is None:
            if self.is_partial and self.query.dtype is not torch.float64:
                return torch.float32
            return self.query.dtype
        return self.output_dtype


@dataclass
class Context:
    """State bundled together with an attention output.

    ``lse`` is indexed with three axes by :meth:`get_padded_lse`, and
    ``out.shape[1]`` is used there as the number of valid entries along the
    last ``lse`` axis.
    """

    lse: torch.Tensor
    out: torch.Tensor
    # NOTE: whenever `rng_state` is provided, `op_bw` should be provided too,
    # because the randomness depends on the backend
    op_bw: Optional[Type["AttentionBwOpBase"]] = None
    rng_state: Optional[Any] = None
    qkv_share_storage: bool = False

    def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
        """Return ``lse`` with its last axis padded to a multiple of ``pad_to``.

        Padding entries are filled with ``+inf``.

        Args:
            pad_to: The last axis of the result has a length that is a
                multiple of this value.
            force_pad_inf: When set, every entry past ``out.shape[1]`` along
                the last axis is forced to ``+inf`` as well. If no padding is
                needed this is done **in place** on ``self.lse``.

        Returns:
            A new padded tensor when padding was required, otherwise
            ``self.lse`` itself.
        """

        def _missing(length: int) -> int:
            # Entries needed to reach the next multiple of `pad_to`.
            return (-length) % pad_to

        result = self.lse
        to_add = _missing(result.shape[2])
        if to_add > 0:
            if force_pad_inf:
                # Drop everything past the valid entries, then pad again.
                result = result[:, :, : self.out.shape[1]]
                to_add = _missing(result.shape[2])
            return torch.nn.functional.pad(result, [0, to_add], value=math.inf)
        if force_pad_inf and self.out.shape[1] != result.shape[2]:
            # Already aligned: overwrite the tail of the stored tensor.
            result[:, :, self.out.shape[1] :].fill_(math.inf)
        return result


@dataclass
class Gradients:
    """Container grouping the gradient tensors of an attention backward pass.

    NOTE(review): ``dq``/``dk``/``dv`` are presumably the gradients with
    respect to query, key and value (inferred from the names) -- confirm
    against the backward operators that build this object.
    """

    dq: torch.Tensor
    dk: torch.Tensor
    dv: torch.Tensor
    # bias gradient. None if there is no tensor bias or if it doesn't require grad
    db: Optional[torch.Tensor] = None


class AttentionOpBase(BaseOperator):
    """Base class for any attention operator in xFormers

    See:

    - :attr:`xformers.ops.fmha.cutlass.FwOp`
    - :attr:`xformers.ops.fmha.cutlass.BwOp`
    - :attr:`xformers.ops.fmha.flash.FwOp`
    - :attr:`xformers.ops.fmha.flash.BwOp`
    - :attr:`xformers.ops.fmha.triton.FwOp`
    - :attr:`xformers.ops.fmha.triton.BwOp`
    - :attr:`xformers.ops.fmha.small_k.FwOp`
    - :attr:`xformers.ops.fmha.small_k.BwOp`
    """

    # Capabilities declared by each concrete operator. They are only read by
    # the `*_not_supported_reasons` checks below.
    OPERATOR: Any
    SUPPORTED_DEVICES: Set[str]
    CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
    SUPPORTED_DTYPES: Set[torch.dtype]
    SUPPORTED_MAX_K: float
    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (type(None),)
    SUPPORTS_DROPOUT: bool
    SUPPORTS_CUSTOM_SCALE: bool = False
    SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
    SUPPORTS_OUTPUT_DTYPE: bool = False
    SUPPORTS_PARTIAL: bool = False
    IS_DETERMINISTIC: bool = True
    SUPPORTS_BMGHK: bool = False
    NAME: str
    OPERATOR_CATEGORY = "memory_efficient_attention"

    _TEST_BATCH_SIZES: List[int] = [1, 300]
    _TEST_K: List[int] = [32, 128]

    @classmethod
    def supports(cls, d: Inputs) -> bool:
        """Return ``True`` when :meth:`not_supported_reasons` finds nothing."""
        return not cls.not_supported_reasons(d)

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        """List the reasons why these embedding dimensions are not supported.

        Args:
            Mq: Not inspected by this base implementation.
            Mkv: Not inspected by this base implementation.
            K: Last dimension of the query.
            Kv: Last dimension of the value.

        Returns:
            Human-readable reasons; empty when the shapes are acceptable.
        """
        reasons = []
        if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
            reasons.append("query.shape[-1] != value.shape[-1]")
        if max(K, Kv) > cls.SUPPORTED_MAX_K:
            # The message mirrors the condition: the larger of both dims.
            reasons.append(
                f"max(query.shape[-1], value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
            )
        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """
        Returns a list of reasons why this is not supported.
        The kernel can run these inputs only if the returned list is empty
        """
        query_shape = d.query.shape
        reasons = cls.shape_not_supported_reasons(
            Mq=query_shape[1],
            Mkv=d.key.shape[1],
            K=query_shape[-1],
            # For int32 values the query's last dim stands in for Kv.
            Kv=query_shape[-1] if d.value.dtype == torch.int32 else d.value.shape[-1],
        )
        device_type = d.query.device.type
        dtype = d.query.dtype
        if device_type not in cls.SUPPORTED_DEVICES:
            reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
        if (
            device_type == "cuda"
            and not _built_with_cuda
            and (torch.version.hip is None)
        ):
            reasons.append("xFormers wasn't build with CUDA support")
        if device_type == "cuda":
            device_capability = torch.cuda.get_device_capability(d.device)
            if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
                # Devices exactly at the minimum are accepted, hence ">=".
                reasons.append(
                    f"requires device with capability >= {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
                    f"but your GPU has capability {device_capability} (too old)"
                )
        if dtype not in cls.SUPPORTED_DTYPES:
            reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
        if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
            reasons.append(f"attn_bias type is {type(d.attn_bias)}")
        if not cls.SUPPORTS_OUTPUT_DTYPE:
            if d.output_dtype is not None and d.output_dtype is not dtype:
                reasons.append("Custom output dtype not supported")
        if d.is_partial and not cls.SUPPORTS_PARTIAL:
            reasons.append("Partial attention not supported")
        if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
            reasons.append("dropout > 0.0")
        if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
            reasons.append("has custom scale")
        # bfloat16 is only supported on A100+
        # ... although the kernels can still run and give the
        # correct result
        if dtype is torch.bfloat16 and (
            not device_type.startswith("cuda")
            or torch.cuda.get_device_capability(d.query.device)[0] < 8
        ):
            reasons.append("bf16 is only supported on A100+ GPUs")
        if not cls.is_available():
            reasons.append(
                "operator wasn't built - see `python -m xformers.info` for more info"
            )
        if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
            reasons.append(
                "operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
            )
        if not cls.SUPPORTS_BMGHK and d.query.ndim == 5:
            reasons.append("operator does not support BMGHK format")
        return reasons
class AttentionFwOpBase(AttentionOpBase):
    """Base class for forward attention operators."""

    # Absolute / relative error tolerances, keyed by dtype.
    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 3e-4,
        torch.half: 4e-3,
        torch.bfloat16: 2e-2,
    }
    ERROR_RTOL: Mapping[torch.dtype, float] = {
        torch.float: 2e-5,
        torch.half: 4e-4,
        torch.bfloat16: 5e-3,
    }

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the forward pass; must be implemented by subclasses."""
        raise NotImplementedError()

    @classmethod
    def attn_operator_flop(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        causal: bool = False,
        seqstart_k: Optional[torch.Tensor] = None,
        seqstart_q: Optional[torch.Tensor] = None,
    ) -> int:
        """
        Computes total flops for the attention
        Assumes inputs in format BMHK

        ``seqstart_q`` / ``seqstart_k``, when given, hold the boundaries of
        consecutive sequences; otherwise a single sequence covering the whole
        ``shape[1]`` axis is assumed. With ``causal`` the total is halved.
        """
        assert query.ndim == 4

        if seqstart_q is not None:
            seqstart_q_py = seqstart_q.tolist()
        else:
            seqstart_q_py = [0, query.shape[1]]
        if seqstart_k is not None:
            seqstart_k_py = seqstart_k.tolist()
        else:
            seqstart_k_py = [0, key.shape[1]]

        # Embedding sizes do not change from one sequence to the next.
        Kqk = query.shape[-1]
        Kv = value.shape[-1]

        total_flop = 0
        for q_start, q_end, k_start, k_end in zip(
            seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
        ):
            num_q = q_end - q_start
            num_kv = k_end - k_start
            # (M,K) @ (K,N) GEMM needs M*N*K*2 flop
            # Q @ K.transpose
            total_flop += num_q * num_kv * Kqk * 2
            # (ignore softmax)
            # attn @ V: (num_q, num_kv) @ (num_kv, Kv), so the inner size is
            # the value embedding, not the key embedding
            total_flop += num_q * Kv * num_kv * 2
        # Multiply by num_heads and batches
        total_flop = total_flop * value.shape[2] * value.shape[0]
        if causal:
            total_flop //= 2
        return total_flop


class AttentionBwOpBase(AttentionOpBase):
    """Base class for backward attention operators."""

    # NOTE on tolerances: These are tested for `scales => (1/32)**0.5`
    # In the BW pass, imprecisions accumulate in the Q@K.T recalculation
    # These imprecisions are multiplied by the `scale` and then exponentiated
    # So if the scale is too high, we get a lot of errors

    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 9e-4,
        torch.half: 0.2,
        torch.bfloat16: 0.9,
    }
    ERROR_RTOL: Mapping[torch.dtype, float] = {
        torch.float: 1e-4,
        torch.half: 2e-2,
        torch.bfloat16: 0.1,
    }
    SUPPORTS_ATTN_BIAS_GRAD = False

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Extend the base checks with the bias-gradient capability check."""
        reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d)
        if (
            isinstance(d.attn_bias, torch.Tensor)
            and d.attn_bias.requires_grad
            and not cls.SUPPORTS_ATTN_BIAS_GRAD
        ):
            reasons.append(
                "Computing the bias gradient is not supported (attn_bias.requires_grad = True)"
            )

        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the backward pass; must be implemented by subclasses."""
        raise NotImplementedError()

    @classmethod
    def attn_operator_flop(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        causal: bool = False,
        seqstart_k: Optional[torch.Tensor] = None,
        seqstart_q: Optional[torch.Tensor] = None,
    ) -> int:
        """
        Computes total flops for the attention
        Assumes inputs in format BMHK

        ``seqstart_q`` / ``seqstart_k``, when given, hold the boundaries of
        consecutive sequences; otherwise a single sequence covering the whole
        ``shape[1]`` axis is assumed. With ``causal`` the total is halved.
        """
        assert query.ndim == 4

        if seqstart_q is not None:
            seqstart_q_py = seqstart_q.tolist()
        else:
            seqstart_q_py = [0, query.shape[1]]
        if seqstart_k is not None:
            seqstart_k_py = seqstart_k.tolist()
        else:
            seqstart_k_py = [0, key.shape[1]]

        # Embedding sizes do not change from one sequence to the next.
        Kqk = query.shape[-1]
        Kv = value.shape[-1]

        total_flop = 0
        for q_start, q_end, k_start, k_end in zip(
            seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
        ):
            num_q = q_end - q_start
            num_kv = k_end - k_start
            # (M,K) @ (K,N) GEMM needs M*N*K*2 flop
            # att = Q @ K.transpose
            total_flop += num_q * num_kv * Kqk * 2
            # att @ dO
            total_flop += num_kv * num_q * Kv * 2
            # dov = dO @ V
            total_flop += num_q * Kv * num_kv * 2
            # dov @ K
            total_flop += num_q * Kqk * num_kv * 2
            # dov @ Q
            total_flop += num_q * Kqk * num_kv * 2
        # Multiply by num_heads and batches
        total_flop = total_flop * value.shape[2] * value.shape[0]
        if causal:
            total_flop //= 2
        return total_flop


# A (forward operator, backward operator) pair; either entry may be None.
AttentionOp = Tuple[
    Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
]


@dataclass
class AttentionOpDispatch:
    """Dispatcher to automatically select
    the best operator to run memory-efficient attention.

    :Deprecated:

        This class is deprecated and will be removed in a later version
    """

    op: AttentionOp

    @classmethod
    def from_arguments(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
        p: float = 0.0,
        scale: Optional[float] = None,
    ) -> "AttentionOpDispatch":
        """Here for backward compatibility"""
        # Imported inside the method rather than at module level.
        from .dispatch import _dispatch_bw, _dispatch_fw

        inp = Inputs(
            query=query,
            key=key,
            value=value,
            attn_bias=attn_bias,
            p=p,
            scale=scale,
        )
        return AttentionOpDispatch(op=(_dispatch_fw(inp, True), _dispatch_bw(inp)))


def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
    """Reshape a 3-D tensor whose first axis is ``batch * num_heads`` to 4-D.

    A 4-D input is returned unchanged. Otherwise the first axis is split into
    ``(shape[0] // num_heads, num_heads)`` and the head axis is moved to
    position 2, giving ``(shape[0] // num_heads, shape[1], num_heads, shape[2])``.
    """
    if tensor.ndim == 4:
        return tensor
    return tensor.reshape(
        [tensor.shape[0] // num_heads, num_heads, tensor.shape[1], tensor.shape[2]]
    ).permute((0, 2, 1, 3))


def check_lastdim_alignment_stride1(
    reasons: List[str], name: str, x: torch.Tensor, alignment: int
) -> None:
    """Append to ``reasons`` any alignment / stride problem found on ``x``.

    Args:
        reasons: List that is extended in place with human-readable messages.
        name: Name used for ``x`` in the messages.
        x: Tensor to inspect.
        alignment: Required divisor of ``x.shape[-1]`` and ``x.stride(-2)``.
    """
    if x.shape[-1] % alignment != 0:
        reasons.append(f"{name}.shape[-1] % {alignment} != 0")
    elif x.stride(-2) % alignment != 0:
        reasons.append(
            f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
        )
    # We can have stride=0 sometimes if dimension=1
    if x.stride(-1) > 1:
        reasons.append(
            f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
        )