# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List, Optional, Sequence, Tuple, Type, Union, cast
import torch
from . import (
attn_bias,
ck,
ck_decoder,
ck_splitk,
cutlass,
decoder,
flash,
small_k,
triton_splitk,
)
from .attn_bias import (
AttentionBias,
BlockDiagonalGappyKeysMask,
BlockDiagonalMask,
BlockDiagonalPaddedKeysMask,
LowerTriangularFromBottomRightMask,
LowerTriangularMask,
PagedBlockDiagonalPaddedKeysMask,
)
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
AttentionOp,
AttentionOpBase,
AttentionOpDispatch,
Context,
Gradients,
Inputs,
bmk2bmhk,
)
from .dispatch import _dispatch_bw, _dispatch_fw, _ensure_op_supports_or_raise
MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp)
MemoryEfficientAttentionDecoderOp = (decoder.FwOp, cutlass.BwOp)
MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp)
MemoryEfficientAttentionOp = (small_k.FwOp, small_k.BwOp)
MemoryEfficientAttentionCkOp = (ck.FwOp, ck.BwOp)
MemoryEfficientAttentionCkDecoderOp = (ck_decoder.FwOp, ck.BwOp)
MemoryEfficientAttentionSplitKCkOp = (ck_splitk.FwOp, ck.BwOp)
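
# Usage note (illustrative sketch): each pair above bundles a forward and a
# backward operator and can be passed as the `op` argument of
# `memory_efficient_attention`, e.g.
#   out = memory_efficient_attention(q, k, v, op=MemoryEfficientAttentionFlashAttentionOp)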
def _deserialize_bias(attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]) -> Any:
if attn_bias_tensor is None:
return attn_bias_ctx
return attn_bias_tensor
class _fMHA(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, op: AttentionOp, *args: Any) -> Any:
inp = Inputs(*args)
op_fw = op[0] if op is not None else None
op_bw = op[1] if op is not None else None
out, op_ctx = _memory_efficient_attention_forward_requires_grad(
inp=inp, op=op_fw
)
        # Saving attn_bias is a bit complicated, as the
        # torch.Tensor part must go in `save_for_backward`
if isinstance(inp.attn_bias, torch.Tensor):
attn_bias_tensor = inp.attn_bias
attn_bias_ctx = None
else:
attn_bias_tensor = None
attn_bias_ctx = inp.attn_bias
ctx.save_for_backward(
inp.query,
inp.key,
inp.value,
op_ctx.out,
op_ctx.lse,
)
ctx.rng_state = op_ctx.rng_state
ctx.attn_bias_tensor = attn_bias_tensor
if op_ctx.op_bw is not None:
if op_bw is not None and op_bw is not op_ctx.op_bw:
raise ValueError(
f"Specified op_bw={op_bw.NAME}, but forward op "
f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
)
op_bw = op_ctx.op_bw
if op_bw is None and (
inp.query.requires_grad or inp.key.requires_grad or inp.value.requires_grad
):
            # NOTE: We need to check tensor strides to decide which operator to run
            # in the BW pass. Unfortunately, PyTorch only allows us to do that during
            # the FW pass, so we select the BW operator now.
op_bw = _dispatch_bw(inp)
ctx.op_fw = op_fw
ctx.op_bw = op_bw
ctx.p = inp.p
        # This allows the gradients to be created in a single shared storage,
        # which avoids a "cat" in the BW pass.
        # The heuristic is approximate, but:
        # (1) Creating a shared storage when it's not needed is not a big issue
        # (2) The heuristic needs to pass `torch.compile`
        # (this is also why we run it in the FW pass; the BW pass is stricter)
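        # (illustrative) stride(-2) == Kq + Kk + Kv matches, e.g., q/k/v sliced
        # along the last dim of a single packed [B, M, H, 3*K] projection, in
        # which case dq/dk/dv can likewise share one storage.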
ctx.qkv_share_storage = (
inp.query.shape[0] == inp.key.shape[0]
and inp.query.shape[-1] == inp.value.shape[-1]
and inp.query.stride(-2)
== (inp.key.shape[-1] + inp.query.shape[-1] + inp.value.shape[-1])
)
ctx.scale = inp.scale
ctx.attn_bias_ctx = attn_bias_ctx
ctx.n_args = len(args)
return out
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad):
# Re-create context
query, key, value, out, lse = ctx.saved_tensors
attn_bias_tensor = ctx.attn_bias_tensor
rng_state = ctx.rng_state
inp = Inputs(
query=query,
key=key,
value=value,
attn_bias=_deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
p=ctx.p,
scale=ctx.scale,
)
op_ctx = Context(
lse=lse,
out=out,
rng_state=rng_state,
)
grads = _memory_efficient_attention_backward(
ctx=op_ctx,
inp=inp,
grad=grad,
op=ctx.op_bw,
_skip_op_checks=True,
)
return (None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * (
ctx.n_args - 2
)
def memory_efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[AttentionOp] = None,
output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""Implements the memory-efficient attention mechanism following
`"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
:Inputs shape:
- Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
      the sequence length, H the number of heads, and K the embedding size per head
- If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1``
- Inputs can also be of dimension 5 with GQA - see note below
- Inputs can be non-contiguous - we only require the last dimension's stride to be 1
:Equivalent pytorch code:
.. code-block:: python
scale = 1.0 / query.shape[-1] ** 0.5
query = query * scale
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn = query @ key.transpose(-2, -1)
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
attn = F.dropout(attn, p)
attn = attn @ value
return attn.transpose(1, 2)
:Examples:
.. code-block:: python
import xformers.ops as xops
# Compute regular attention
y = xops.memory_efficient_attention(q, k, v)
# With a dropout of 0.2
y = xops.memory_efficient_attention(q, k, v, p=0.2)
# Causal attention
y = xops.memory_efficient_attention(
q, k, v,
attn_bias=xops.LowerTriangularMask()
)
:Supported hardware:
NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``.
:EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA):
MQA/GQA is an experimental feature supported only for the forward pass.
If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors
in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and
``H`` is the number of heads per group (8 in the example).
    Please note that xFormers will not automatically broadcast the inputs, so you will need
    to broadcast them manually before calling `memory_efficient_attention`.
:GQA/MQA example:
.. code-block:: python
import torch
import xformers.ops as xops
B, M, K = 3, 32, 128
kwargs = dict(device="cuda", dtype=torch.float16)
q = torch.randn([B, M, 8, K], **kwargs)
k = torch.randn([B, M, 2, K], **kwargs)
v = torch.randn([B, M, 2, K], **kwargs)
out_gqa = xops.memory_efficient_attention(
q.reshape([B, M, 2, 4, K]),
k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
)
Raises:
NotImplementedError: if there is no operator available to compute the MHA
ValueError: if inputs are invalid
:parameter query: Tensor of shape ``[B, Mq, H, K]``
:parameter key: Tensor of shape ``[B, Mkv, H, K]``
:parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
:parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
:parameter p: Dropout probability. Disabled if set to ``0.0``
:parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
scale (q.shape[-1]**-0.5) will be used.
:parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
If set to ``None`` (recommended), xFormers \
will dispatch to the best available operator, depending on the inputs \
and options.
:return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
"""
return _memory_efficient_attention(
Inputs(
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
),
op=op,
)
def memory_efficient_attention_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionFwOpBase]] = None,
output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""
Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
"""
return _memory_efficient_attention_forward(
Inputs(
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
),
op=op,
)
def memory_efficient_attention_forward_requires_grad(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionFwOpBase]] = None,
output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
    Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
    See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
    See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass.
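
    :Example: (a minimal sketch; ``q``, ``k``, ``v`` are assumed to be
        ``[B, M, H, K]`` tensors on a supported GPU)

    .. code-block:: python

        out, lse = memory_efficient_attention_forward_requires_grad(q, k, v)
        # keep `lse` around: it is needed to run the backward pass later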
"""
if p != 0.0:
raise NotImplementedError(
"dropout is not supported on the non-autograd API."
" If you want to use dropout, please call `memory_efficient_attention` directly"
)
out, ctx = _memory_efficient_attention_forward_requires_grad(
Inputs(
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
),
op=op,
)
return out, ctx.lse
def memory_efficient_attention_backward(
grad: torch.Tensor,
output: torch.Tensor,
lse: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionBwOpBase]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the gradient of the attention.
    Returns a tuple (dq, dk, dv).
See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
`lse` is the tensor returned by
:attr:`xformers.ops.memory_efficient_attention_forward_requires_grad`
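
    :Example: (a minimal sketch; ``grad`` is assumed to be the upstream gradient,
        with the same shape as ``output``)

    .. code-block:: python

        out, lse = memory_efficient_attention_forward_requires_grad(q, k, v)
        dq, dk, dv = memory_efficient_attention_backward(grad, out, lse, q, k, v)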
"""
if p != 0.0:
raise NotImplementedError(
"dropout is not supported on the non-autograd API."
" If you want to use dropout, please call `memory_efficient_attention` directly"
)
gradients = _memory_efficient_attention_backward(
Context(out=output, lse=lse),
Inputs(
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
),
grad,
op=op,
)
return (gradients.dq, gradients.dk, gradients.dv)
def _memory_efficient_attention(
inp: Inputs, op: Optional[AttentionOp] = None
) -> torch.Tensor:
# fast-path that doesn't require computing the logsumexp for backward computation
if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
return _memory_efficient_attention_forward(
inp, op=op[0] if op is not None else None
)
output_shape = inp.normalize_bmhk()
return _fMHA.apply(
op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale
).reshape(output_shape)
def _memory_efficient_attention_forward(
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> torch.Tensor:
inp.validate_inputs()
output_shape = inp.normalize_bmhk()
if op is None:
op = _dispatch_fw(inp, False)
else:
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
out, *_ = op.apply(inp, needs_gradient=False)
return out.reshape(output_shape)
def _memory_efficient_attention_forward_requires_grad(
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> Tuple[torch.Tensor, Context]:
inp.validate_inputs()
output_shape = inp.normalize_bmhk()
if op is None:
op = _dispatch_fw(inp, True)
else:
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
out = op.apply(inp, needs_gradient=True)
assert out[1] is not None
return (out[0].reshape(output_shape), out[1])
def _memory_efficient_attention_backward(
ctx: Context,
inp: Inputs,
grad: torch.Tensor,
op: Optional[Type[AttentionBwOpBase]],
*,
_skip_op_checks: bool = False,
) -> Gradients:
"""Warning: grad/ctx.out is potentially in BMK format"""
inp.validate_inputs()
if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim:
raise ValueError(
"All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
f"grad.shape : {grad.shape} \n"
f"out.shape : {ctx.out.shape} \n"
f"query.shape: {inp.query.shape}"
)
shape_dq, shape_dk, shape_dv = tuple(
x.shape for x in (inp.query, inp.key, inp.value)
)
inp.normalize_bmhk()
# LSE has shape [B, H, M] while query has shape [B, M, H, K]
if (
ctx.lse.ndim != 3
# Dim 0
or (
not isinstance(inp.attn_bias, BlockDiagonalMask)
and ctx.lse.shape[0] != inp.query.shape[0]
)
or (
isinstance(inp.attn_bias, BlockDiagonalMask)
and ctx.lse.shape[0] != inp.attn_bias.q_seqinfo.seqstart.shape[0] - 1
)
# Dim 1
or ctx.lse.shape[1] != inp.query.shape[2]
# Dim 2
or (
not isinstance(inp.attn_bias, BlockDiagonalMask)
and ctx.lse.shape[2] < inp.query.shape[1]
)
):
raise ValueError(
"Input tensors have incompatible shapes."
f"lse.shape : {ctx.lse.shape} \n"
f"query.shape : {inp.query.shape}"
)
grad = bmk2bmhk(grad, 1)
ctx.out = bmk2bmhk(ctx.out, 1)
if op is None:
op = _dispatch_bw(inp)
elif not _skip_op_checks:
_ensure_op_supports_or_raise(
ValueError, "memory_efficient_attention_backward", op, inp
)
grads = op.apply(ctx, inp, grad)
grads.dq = grads.dq.reshape(shape_dq)
grads.dk = grads.dk.reshape(shape_dk)
grads.dv = grads.dv.reshape(shape_dv)
return grads
def memory_efficient_attention_partial(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionFwOpBase]] = None,
output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
    Returns a tuple (output, lse), where `output` is the attention output and `lse`
    is its logsumexp. The outputs of calls to this with the same query
    and separate keys and values can be merged with merge_attentions to obtain
    the attention of the queries against the disjoint union of the keys and values.
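
    :Example: (a minimal sketch; assumes ``[B, M, H, K]`` inputs with
        ``k = torch.cat([k1, k2], dim=1)`` and likewise for ``v``)

    .. code-block:: python

        out1, lse1 = memory_efficient_attention_partial(q, k1, v1)
        out2, lse2 = memory_efficient_attention_partial(q, k2, v2)
        # equivalent to attending to the concatenated K/V in a single call
        out, _ = merge_attentions([out1, out2], [lse1, lse2], write_lse=False)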
"""
if p != 0.0:
raise NotImplementedError("dropout is not supported.")
if not isinstance(
attn_bias,
(
type(None),
BlockDiagonalGappyKeysMask,
BlockDiagonalPaddedKeysMask,
PagedBlockDiagonalPaddedKeysMask,
LowerTriangularFromBottomRightMask,
LowerTriangularMask,
),
):
raise ValueError(
f"{type(attn_bias)} is not supported in memory_efficient_attention_partial."
)
out, ctx = _memory_efficient_attention_forward_requires_grad(
Inputs(
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
is_partial=True,
),
op=op,
)
return out, ctx.lse
def merge_attentions(
attn_split: Union[torch.Tensor, List[torch.Tensor]],
lse_split: Union[torch.Tensor, List[torch.Tensor]],
write_lse: bool = True,
output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Combine attention output computed on different parts of K/V for the same
query to get attention on the whole K/V. See https://arxiv.org/abs/2402.05099
The result is equal to
Out_full = (Out1 * exp(LSE1) + Out2 * exp(LSE2) + ...) / (exp(LSE1) + exp(LSE2) + ...)
LSE_full = log(exp(LSE1) + exp(LSE2) + ...)
Args:
attn_split: attention outputs for chunks,
either as a list of tensors of shapes [B, M, G, H, Kq] or [B, M, H, Kq]
or as a single tensor of shape [num_chunks, B, M, G, H, Kq]
or [num_chunks, B, M, H, Kq]
lse_split: LSE for chunks,
either as a list of tensors of shapes [B, G, H, M] or [B, H, M]
or as a single tensor of shape [num_chunks, B, G, H, M] or [num_chunks, B, H, M]
write_lse: whether to output LSE
        output_dtype: dtype of attn_out
Returns:
attn_out: [B, M, G, H, Kq] or [B, M, H, Kq]
lse_out: [B, G, H, M] or [B, H, M] if write_lse
or None otherwise
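
    Example (a minimal sketch, with stacked inputs in the shapes above)::

        # attn_split: [num_chunks, B, M, H, Kq]; lse_split: [num_chunks, B, H, M]
        attn_out, lse_out = merge_attentions(attn_split, lse_split)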
"""
attn_is_concat = isinstance(attn_split, torch.Tensor)
lse_is_concat = isinstance(lse_split, torch.Tensor)
concat_path = attn_is_concat and lse_is_concat
if not concat_path:
if attn_is_concat:
attn_split = cast(torch.Tensor, attn_split).unbind(0)
if lse_is_concat:
lse_split = cast(torch.Tensor, lse_split).unbind(0)
if concat_path:
attn_split = cast(torch.Tensor, attn_split)
lse_split = cast(torch.Tensor, lse_split)
if attn_split.ndim != lse_split.ndim + 1:
raise ValueError(
f"Incompatible input shapes: {attn_split.shape=}, {lse_split.shape=}"
)
is_bmhk = attn_split.ndim == 5
if is_bmhk:
attn_split = attn_split.unsqueeze(3)
lse_split = lse_split.unsqueeze(2)
num_chunks, B, M, G, H, Kq = attn_split.shape
num_chunks1, B1, G1, H1, M1 = lse_split.shape
        if B != B1 or G != G1 or H != H1 or num_chunks != num_chunks1 or M != M1:
            raise ValueError(
                f"Incompatible input shapes: {attn_split.shape=} {lse_split.shape=} "
                f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {num_chunks}/{num_chunks1}, {M}/{M1}"
            )
attn_split = attn_split.permute(1, 3, 4, 0, 2, 5)
lse_split = lse_split.permute(1, 2, 3, 0, 4)
device = attn_split.device
attn_dtype = attn_split.dtype
lse_dtype = lse_split.dtype
merge_func: Any = triton_splitk.merge_attentions
else:
num_chunks = len(attn_split)
if len(lse_split) != num_chunks:
raise ValueError(
f"Incompatible number of LSE and attention chunks: {len(attn_split)=}, {len(lse_split)=}"
)
attn_unsqueezed = []
lse_unsqueezed = []
is_bmhk = False
for i in range(num_chunks):
if attn_split[i].ndim != lse_split[i].ndim + 1:
raise ValueError(
f"Incompatible input shapes for chunk {i}: {attn_split[i].shape=}, {lse_split[i].shape=}"
)
is_bmhk = attn_split[i].ndim == 4
if is_bmhk:
attn_unsqueezed.append(attn_split[i].unsqueeze(2))
lse_unsqueezed.append(lse_split[i].unsqueeze(1))
else:
attn_unsqueezed.append(attn_split[i])
lse_unsqueezed.append(lse_split[i])
attn_split, lse_split = attn_unsqueezed, lse_unsqueezed
B, M, G, H, Kq = attn_split[0].shape
B1, G1, H1, M1 = lse_split[0].shape
        if B != B1 or G != G1 or H != H1 or M != M1:
            raise ValueError(
                f"Incompatible input shapes: {attn_split[0].shape=}, {lse_split[0].shape=} "
                f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {M}/{M1}"
            )
for i in range(num_chunks):
if attn_split[i].shape != (B, M, G, H, Kq):
raise ValueError(
f"Incompatible input shapes for attention chunk {i}: "
f"{attn_split[i].shape=}, {(B, M, G, H, Kq)=}"
)
if lse_split[i].shape != (B, G, H, M):
raise ValueError(
f"Incompatible input shapes for LSE chunk {i}: "
f"{lse_split[i].shape=}, {(B, G, H, M)=}"
)
attn_split[i] = attn_split[i].permute(0, 2, 3, 1, 4) # to (B, G, H, M, Kq)
device = attn_split[0].device
attn_dtype = attn_split[0].dtype
lse_dtype = lse_split[0].dtype
merge_func = triton_splitk.merge_attentions_varargs
attn_out = torch.empty(
B,
M,
G,
H,
Kq,
device=device,
dtype=output_dtype or attn_dtype,
)
if write_lse:
lse_out = torch.empty(B, G, H, M, device=device, dtype=lse_dtype)
else:
lse_out = None
merge_func(attn_out, lse_out, attn_split, lse_split) # type: ignore
if is_bmhk:
attn_out = attn_out[:, :, 0]
if lse_out is not None:
lse_out = lse_out[:, 0]
return attn_out, lse_out
ALL_FW_OPS: Sequence[Type[AttentionFwOpBase]] = [
cutlass.FwOp if torch.version.cuda else ck.FwOp,
flash.FwOp,
small_k.FwOp,
triton_splitk.FwOp,
]
ALL_BW_OPS: Sequence[Type[AttentionBwOpBase]] = [
cutlass.BwOp if torch.version.cuda else ck.BwOp,
flash.BwOp,
small_k.BwOp,
]
__all__ = [
"AttentionBias",
"AttentionOp",
"AttentionOpBase",
"AttentionOpDispatch",
"LowerTriangularMask",
"MemoryEfficientAttentionCutlassFwdFlashBwOp",
"MemoryEfficientAttentionCutlassOp",
"MemoryEfficientAttentionFlashAttentionOp",
"MemoryEfficientAttentionOp",
"memory_efficient_attention",
"MemoryEfficientAttentionCkOp",
"MemoryEfficientAttentionCkDecoderOp",
"ALL_FW_OPS",
"ALL_BW_OPS",
"attn_bias",
]