# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List, Optional, Sequence, Tuple, Type, Union, cast
import torch
from . import (
attn_bias,
ck,
ck_decoder,
ck_splitk,
cutlass,
flash,
flash3,
triton_splitk,
)
from .attn_bias import (
VARLEN_BIASES,
AttentionBias,
BlockDiagonalMask,
LowerTriangularMask,
)
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
AttentionOp,
AttentionOpBase,
Context,
Gradients,
Inputs,
bmk2bmhk,
)
from .dispatch import (
_dispatch_bw,
_dispatch_fw,
_ensure_op_supports_or_raise,
_get_use_fa3,
_set_use_fa3,
)
MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp)
MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp)
MemoryEfficientAttentionCkOp = (ck.FwOp, ck.BwOp)
MemoryEfficientAttentionCkDecoderOp = (ck_decoder.FwOp, ck.BwOp)
MemoryEfficientAttentionSplitKCkOp = (ck_splitk.FwOp, ck.BwOp)
def _deserialize_bias(attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]) -> Any:
if attn_bias_tensor is None:
return attn_bias_ctx
return attn_bias_tensor
# Note: `torch.compile` only allows custom autograd functions
# to accept a subset of types. Therefore we serialize `op` objects
# to `str` before entering the function, and unserialize them inside.
# See also: https://github.com/pytorch/pytorch/issues/118395
_OPS_LOOKUP = {
flash.FwOp.NAME: flash.FwOp,
flash.BwOp.NAME: flash.BwOp,
}
def _serialize_op(op):
if op is not None and op.NAME in _OPS_LOOKUP:
return op.NAME
return op
def _unserialize_op(op):
if isinstance(op, str):
return _OPS_LOOKUP[op]
return op
class _fMHA(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, op_fw, op_bw, *args: Any) -> Any:
inp = Inputs(*args)
op_fw = _unserialize_op(op_fw)
op_bw = _unserialize_op(op_bw)
out, op_ctx = _memory_efficient_attention_forward_requires_grad(
inp=inp, op=op_fw
)
# Saving attn_bias is a bit complicated, as the
# torch part should go in `save_for_backward`
if isinstance(inp.attn_bias, torch.Tensor):
attn_bias_tensor = inp.attn_bias
attn_bias_ctx = None
else:
attn_bias_tensor = None
attn_bias_ctx = inp.attn_bias
ctx.save_for_backward(
inp.query,
inp.key,
inp.value,
op_ctx.out,
op_ctx.lse,
)
ctx.rng_state = op_ctx.rng_state
ctx.attn_bias_tensor = attn_bias_tensor
if op_ctx.op_bw is not None:
if op_bw is not None and op_bw is not op_ctx.op_bw:
raise ValueError(
f"Specified op_bw={op_bw.NAME}, but forward op "
f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
)
op_bw = op_ctx.op_bw
if (
op_bw is not None
and isinstance(inp.attn_bias, VARLEN_BIASES)
and inp.attn_bias.q_seqinfo.seqstart.shape[0] > 2
and op_bw.VARLEN_LSE_PACKED != op_fw.VARLEN_LSE_PACKED
):
raise ValueError(
f"Specified op_bw={op_bw.NAME} is not compatible with the "
f"op_fw={op_fw.NAME}, because they use different format of logsumexp. "
f"NOTE: This is new with xFormers 0.0.28"
)
if op_bw is None and (
inp.query.requires_grad or inp.key.requires_grad or inp.value.requires_grad
):
varlen_lse_packed = _detect_lse_packed_or_raise(op_ctx.lse, inp)
if varlen_lse_packed is not None and op_fw is not None:
assert (
op_fw.VARLEN_LSE_PACKED == varlen_lse_packed
), f"{op_fw.NAME}: wrong value for `VARLEN_LSE_PACKED` ?"
# NOTE: We need to check tensor strides to decide which operator we run in the BW pass.
# Unfortunately, PyTorch only allows to call this function during the FW pass, so
# we decide the operator to use now.
op_bw = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed)
ctx.op_fw = op_fw
ctx.op_bw = op_bw
ctx.p = inp.p
# This allows to create gradients from a single storage,
# to avoid a "cat" in the BW pass.
# The heuristic is approximative, but:
# (1) It's not a big issue to create a shared storage
# (2) The heuristic needs to pass `torch.compile`
# (this is also why we run it in the FW pass, the BW pass is stricter)
ctx.qkv_share_storage = (
inp.query.shape[0] == inp.key.shape[0]
and inp.query.shape[-1] == inp.value.shape[-1]
and inp.query.stride(-2)
== (inp.key.shape[-1] + inp.query.shape[-1] + inp.value.shape[-1])
)
ctx.scale = inp.scale
ctx.attn_bias_ctx = attn_bias_ctx
ctx.n_args = len(args)
return out, op_ctx.lse
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad, grad_lse):
# Re-create context
query, key, value, out, lse = ctx.saved_tensors
attn_bias_tensor = ctx.attn_bias_tensor
rng_state = ctx.rng_state
inp = Inputs(
query=query,
key=key,
value=value,
attn_bias=_deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
p=ctx.p,
scale=ctx.scale,
)
op_ctx = Context(
lse=lse,
out=out,
rng_state=rng_state,
)
grads = _memory_efficient_attention_backward(
ctx=op_ctx,
inp=inp,
grad=grad,
op=ctx.op_bw,
_skip_op_checks=True,
)
return (None, None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * (
ctx.n_args - 2
)
[docs]def memory_efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[AttentionOp] = None,
output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""Implements the memory-efficient attention mechanism following
`"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
:Inputs shape:
- Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
the sequence length, H the number of heads, and K the embeding size per head
- If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1``
- Inputs can also be of dimension 5 with GQA - see note below
- Inputs can be non-contiguous - we only require the last dimension's stride to be 1
:Equivalent pytorch code:
.. code-block:: python
scale = 1.0 / query.shape[-1] ** 0.5
query = query * scale
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn = query @ key.transpose(-2, -1)
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
attn = F.dropout(attn, p)
attn = attn @ value
return attn.transpose(1, 2).contiguous()
:Examples:
.. code-block:: python
import xformers.ops as xops
# Compute regular attention
y = xops.memory_efficient_attention(q, k, v)
# With a dropout of 0.2
y = xops.memory_efficient_attention(q, k, v, p=0.2)
# Causal attention
y = xops.memory_efficient_attention(
q, k, v,
attn_bias=xops.LowerTriangularMask()
)
:Supported hardware:
NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``.
:EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA):
MQA/GQA is an experimental feature supported only for the forward pass.
If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors
in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and
``H`` is the number of heads per group (8 in the example).
Please note that xFormers will not automatically broadcast the inputs, so you will need
to broadcast it manually before calling `memory_efficient_attention`.
:GQA/MQA example:
.. code-block:: python
import torch
import xformers.ops as xops
B, M, K = 3, 32, 128
kwargs = dict(device="cuda", dtype=torch.float16)
q = torch.randn([B, M, 8, K], **kwargs)
k = torch.randn([B, M, 2, K], **kwargs)
v = torch.randn([B, M, 2, K], **kwargs)
out_gqa = xops.memory_efficient_attention(
q.reshape([B, M, 2, 4, K]),
k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
)
Raises:
NotImplementedError: if there is no operator available to compute the MHA
ValueError: if inputs are invalid
:parameter query: Tensor of shape ``[B, Mq, H, K]``
:parameter key: Tensor of shape ``[B, Mkv, H, K]``
:parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
:parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
:parameter p: Dropout probability. Disabled if set to ``0.0``
:parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
scale (q.shape[-1]**-0.5) will be used.
:parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
If set to ``None`` (recommended), xFormers \
will dispatch to the best available operator, depending on the inputs \
and options.
:return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
"""
return _memory_efficient_attention(
Inputs(
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
),
op=op,
)
torch.library.define(
"xformer::memory_efficient_attention_forward",
"(Tensor q, Tensor k, Tensor v, Tensor? b = None, float? p = 0.0, float? scale = None) -> Tensor",
)
@torch.library.impl("xformer::memory_efficient_attention_forward", "Meta")
def memory_efficient_attention_forward_meta(q, k, v):
return q.new_empty(q.shape)
# torch.compile has issue when tracing through op dispatch and ensure_op_support
# so provide a wrapper to register it as a custom torch library op.
@torch.library.impl("xformer::memory_efficient_attention_forward", "CUDA")
def memory_efficient_attention_forward_torch_wrapper(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
) -> torch.Tensor:
"""
This provides a torch-compilable wrapper op to
memory_efficient_attention_forward in certain special cases.
Note that the following are not supported
- `op` input (?)
- certain attn_bias types (?)
- output_dtype
- K != Kv
"""
return memory_efficient_attention_forward(
query,
key,
value,
attn_bias,
p,
scale,
)
[docs]def memory_efficient_attention_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionFwOpBase]] = None,
output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""
Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
"""
return _memory_efficient_attention_forward(
Inputs(
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
),
op=op,
)
[docs]def memory_efficient_attention_forward_requires_grad(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionFwOpBase]] = None,
output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments
See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass
"""
if p != 0.0:
raise NotImplementedError(
"dropout is not supported on the non-autograd API."
" If you want to use dropout, please call `memory_efficient_attention` directly"
)
out, ctx = _memory_efficient_attention_forward_requires_grad(
Inputs(
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
),
op=op,
)
return out, ctx.lse
[docs]def memory_efficient_attention_backward(
grad: torch.Tensor,
output: torch.Tensor,
lse: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionBwOpBase]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the gradient of the attention.
Returns a tuple (dq, dk, dv)
See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
`lse` is the tensor returned by
:attr:`xformers.ops.memory_efficient_attention_forward_requires_grad`
"""
if p != 0.0:
raise NotImplementedError(
"dropout is not supported on the non-autograd API."
" If you want to use dropout, please call `memory_efficient_attention` directly"
)
gradients = _memory_efficient_attention_backward(
Context(out=output, lse=lse),
Inputs(
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
),
grad,
op=op,
)
return (gradients.dq, gradients.dk, gradients.dv)
def _memory_efficient_attention(
inp: Inputs, op: Optional[AttentionOp] = None
) -> torch.Tensor:
# fast-path that doesn't require computing the logsumexp for backward computation
if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
return _memory_efficient_attention_forward(
inp, op=op[0] if op is not None else None
)
output_shape = inp.normalize_bmhk()
op_fw = _serialize_op(op[0] if op is not None else None)
op_bw = _serialize_op(op[1] if op is not None else None)
return _fMHA.apply(
op_fw, op_bw, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale
)[0].reshape(output_shape)
def _memory_efficient_attention_forward(
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> torch.Tensor:
inp.validate_inputs()
output_shape = inp.normalize_bmhk()
if op is None:
op = _dispatch_fw(inp, False)
else:
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
out, *_ = op.apply(inp, needs_gradient=False)
return out.reshape(output_shape)
def _memory_efficient_attention_forward_requires_grad(
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> Tuple[torch.Tensor, Context]:
inp.validate_inputs()
output_shape = inp.normalize_bmhk()
if op is None:
op = _dispatch_fw(inp, True)
else:
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
out = op.apply(inp, needs_gradient=True)
assert out[1] is not None
return (out[0].reshape(output_shape), out[1])
def _detect_lse_packed_or_raise(lse: torch.Tensor, inp: Inputs) -> Optional[bool]:
"""
Detects the LSE format if we're in a varlen case.
Returns `None` if the format is not relevant (eg not varlen)
Raises an exception if the `lse` has the wrong shape
"""
shape_mismatch_err = (
"Input tensors have incompatible shapes.\n"
f" lse.shape : {lse.shape}\n"
f" query.shape : {inp.query.shape}\n"
f" attn_bias : {type(inp.attn_bias)}"
)
# 1. Check ndim & head dimensions
# In any case, LSE should be [*, *GH]
if lse.ndim != (inp.query.ndim - 1) or lse.shape[1:-1] != inp.query.shape[2:-1]:
raise ValueError(shape_mismatch_err)
lse_bm = [lse.shape[0], lse.shape[-1]]
lse_packed_shape = [inp.query.shape[0], inp.query.shape[1]]
lse_packed = lse_bm[0] == lse_packed_shape[0] and lse_bm >= lse_packed_shape
# 2. Check correctness for varlen biases with query.shape = [1, M, *GH, K]
# Either [1, *GH, M] (packed)
# Or [num_seq, *GH, Mq] .. with `Mq >= max_q` (padded)
if isinstance(inp.attn_bias, VARLEN_BIASES):
si = inp.attn_bias.q_seqinfo
lse_padded_shape = [si.seqstart.shape[0] - 1, si.max_seqlen]
lse_padded = lse_bm[0] == lse_padded_shape[0] and lse_bm >= lse_padded_shape
if lse_packed and lse_padded:
return None
elif lse_packed:
return True
elif lse_padded:
return False
raise ValueError(shape_mismatch_err)
# 3. For non-varlen, shape must be [B, *GH] with query.shape=[B, M, *GH, K]
if not lse_packed:
raise ValueError(shape_mismatch_err)
return None
def _memory_efficient_attention_backward(
ctx: Context,
inp: Inputs,
grad: torch.Tensor,
op: Optional[Type[AttentionBwOpBase]],
*,
_skip_op_checks: bool = False,
) -> Gradients:
"""Warning: grad/ctx.out is potentially in BMK format"""
inp.validate_inputs()
if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim:
raise ValueError(
"All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
f"grad.shape : {grad.shape} \n"
f"out.shape : {ctx.out.shape} \n"
f"query.shape: {inp.query.shape}"
)
shape_dq, shape_dk, shape_dv = tuple(
x.shape for x in (inp.query, inp.key, inp.value)
)
inp.normalize_bmhk()
varlen_lse_packed = _detect_lse_packed_or_raise(ctx.lse, inp)
grad = bmk2bmhk(grad, 1)
ctx.out = bmk2bmhk(ctx.out, 1)
if op is None:
op = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed)
elif not _skip_op_checks:
_ensure_op_supports_or_raise(
ValueError, "memory_efficient_attention_backward", op, inp
)
if varlen_lse_packed is not None and varlen_lse_packed != op.VARLEN_LSE_PACKED:
raise ValueError(
f"Wrong LSE format for {op.NAME} in variable seqlen case. "
f"Double-check that the BW operator {op.NAME} is compatible "
f"with the operator used in the FW pass."
)
grads = op.apply(ctx, inp, grad)
grads.dq = grads.dq.reshape(shape_dq)
grads.dk = grads.dk.reshape(shape_dk)
grads.dv = grads.dv.reshape(shape_dv)
return grads
[docs]def memory_efficient_attention_partial(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[Union[AttentionOp, Type[AttentionFwOpBase]]] = None,
output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns a tuple (output, lse), where `output` is the attention in the style of
memory_efficient_attention, and `lse` is extra data, a log-sum-exp.
The outputs of calls to this with the same query and separate keys and values
can be merged with merge_attentions to obtain the attention of the queries
against the disjoint union of the keys and values.
Warning: The backward pass of this function is quite restricted. In particular
we assume that in the forward pass the outputs were only used in merge_attention
calculations, and that LSEs weren't used anywhere except in merge attentions.
"""
if p != 0.0:
raise NotImplementedError("dropout is not supported.")
fwop: Optional[Type[AttentionFwOpBase]] = op[0] if isinstance(op, tuple) else op
inp = Inputs(
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
is_partial=True,
)
is_grad = torch.is_grad_enabled() and any(
x.requires_grad for x in [query, key, value]
)
if not is_grad:
out, ctx = _memory_efficient_attention_forward_requires_grad(
inp,
op=fwop,
)
return out, ctx.lse
if query.ndim == 5:
raise ValueError("gradients not supported for 5D tensors")
if isinstance(op, tuple):
op_fw = _serialize_op(op[0])
op_bw = _serialize_op(op[1])
elif op is None:
op_fw = op_bw = None
else:
op_fw = _serialize_op(op)
op_bw = None
return _fMHA.apply(
op_fw,
op_bw,
inp.query,
inp.key,
inp.value,
inp.attn_bias,
inp.p,
inp.scale,
inp.output_dtype,
inp.is_partial,
)
[docs]def merge_attentions(
attn_split: Union[torch.Tensor, Sequence[torch.Tensor]],
lse_split: Union[torch.Tensor, Sequence[torch.Tensor]],
write_lse: bool = True,
output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Combine attention output computed on different parts of K/V for the same
query to get attention on the whole K/V. See https://arxiv.org/abs/2402.05099
The result is equal to
Out_full = (Out1 * exp(LSE1) + Out2 * exp(LSE2) + ...) / (exp(LSE1) + exp(LSE2) + ...)
LSE_full = log(exp(LSE1) + exp(LSE2) + ...)
Args:
attn_split: attention outputs for chunks,
either as a list of tensors of shapes [B, M, G, H, Kq] or [B, M, H, Kq]
or as a single tensor of shape [num_chunks, B, M, G, H, Kq]
or [num_chunks, B, M, H, Kq]
lse_split: LSE for chunks,
either as a list of tensors of shapes [B, G, H, M] or [B, H, M]
or as a single tensor of shape [num_chunks, B, G, H, M] or [num_chunks, B, H, M]
write_lse: whether to output LSE
output_dtype: dtype of attn_out
Returns:
attn_out: [B, M, G, H, Kq] or [B, M, H, Kq]
lse_out: [B, G, H, M] or [B, H, M] if write_lse
or None otherwise
"""
attn_is_concat = isinstance(attn_split, torch.Tensor)
lse_is_concat = isinstance(lse_split, torch.Tensor)
attn_requires_grad = (
attn_split.requires_grad # type: ignore
if attn_is_concat
else any(x.requires_grad for x in attn_split)
)
lse_requires_grad = (
lse_split.requires_grad # type: ignore
if lse_is_concat
else any(x.requires_grad for x in lse_split)
)
requires_grad = torch.is_grad_enabled() and (
attn_requires_grad or lse_requires_grad
)
if requires_grad and not write_lse:
raise ValueError("write_lse should be true if inputs require gradients.")
concat_path = attn_is_concat and lse_is_concat and not requires_grad
if concat_path:
attn_split = cast(torch.Tensor, attn_split)
lse_split = cast(torch.Tensor, lse_split)
if attn_split.ndim != lse_split.ndim + 1:
raise ValueError(
f"Incompatible input shapes: {attn_split.shape=}, {lse_split.shape=}"
)
is_bmhk = attn_split.ndim == 5
if is_bmhk:
attn_split = attn_split.unsqueeze(3)
lse_split = lse_split.unsqueeze(2)
num_chunks, B, M, G, H, Kq = attn_split.shape
num_chunks1, B1, G1, H1, M1 = lse_split.shape
if B != B1 or G != G1 or H != H1 or num_chunks != num_chunks1 or M != M:
raise ValueError(
f"Incompatible input shapes: {attn_split.shape=} {lse_split.shape=} "
f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {num_chunks}/{num_chunks1}, {M}/{M}"
)
attn_split = attn_split.permute(1, 3, 4, 0, 2, 5)
lse_split = lse_split.permute(1, 2, 3, 0, 4)
device = attn_split.device
attn_dtype = attn_split.dtype
lse_dtype = lse_split.dtype
else:
if attn_is_concat:
attn_split = attn_split.unbind(0) # type: ignore
if lse_is_concat:
lse_split = lse_split.unbind(0) # type: ignore
num_chunks = len(attn_split)
if len(lse_split) != num_chunks:
raise ValueError(
f"Incompatible number of LSE and attention chunks: {len(attn_split)=}, {len(lse_split)=}"
)
attn_unsqueezed = []
lse_unsqueezed = []
is_bmhk = False
for i in range(num_chunks):
if attn_split[i].ndim != lse_split[i].ndim + 1:
raise ValueError(
f"Incompatible input shapes for chunk {i}: {attn_split[i].shape=}, {lse_split[i].shape=}"
)
is_bmhk = attn_split[i].ndim == 4
if is_bmhk:
attn_unsqueezed.append(attn_split[i].unsqueeze(2))
lse_unsqueezed.append(lse_split[i].unsqueeze(1))
else:
attn_unsqueezed.append(attn_split[i])
lse_unsqueezed.append(lse_split[i])
attn_split, lse_split = attn_unsqueezed, lse_unsqueezed
B, M, G, H, Kq = attn_split[0].shape
B1, G1, H1, M1 = lse_split[0].shape
if B != B1 or G != G1 or H != H1 or M != M:
raise ValueError(
f"Incompatible input shapes: {attn_split[0].shape=}, {lse_split[0].shape=} "
f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {M}/{M}"
)
for i in range(num_chunks):
if attn_split[i].shape != (B, M, G, H, Kq):
raise ValueError(
f"Incompatible input shapes for attention chunk {i}: "
f"{attn_split[i].shape=}, {(B, M, G, H, Kq)=}"
)
if lse_split[i].shape != (B, G, H, M):
raise ValueError(
f"Incompatible input shapes for LSE chunk {i}: "
f"{lse_split[i].shape=}, {(B, G, H, M)=}"
)
attn_split[i] = attn_split[i].permute(0, 2, 3, 1, 4) # to (B, G, H, M, Kq)
device = attn_split[0].device
attn_dtype = attn_split[0].dtype
lse_dtype = lse_split[0].dtype
attn_out = torch.empty(
B,
M,
G,
H,
Kq,
device=device,
dtype=output_dtype or attn_dtype,
requires_grad=requires_grad,
)
if write_lse:
lse_out = torch.empty(
B, G, H, M, device=device, dtype=lse_dtype, requires_grad=requires_grad
)
else:
lse_out = None
if concat_path:
triton_splitk.merge_attentions(attn_out, lse_out, attn_split, lse_split) # type: ignore
else:
attn_out, lse_out = _MergeAttentions.apply(attn_out, lse_out, *attn_split, *lse_split) # type: ignore
if is_bmhk:
attn_out = attn_out[:, :, 0]
if lse_out is not None:
lse_out = lse_out[:, 0]
return attn_out, lse_out
class _MergeAttentions(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(
ctx, attn_out: torch.Tensor, lse_out: torch.Tensor, *inputs: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
num_chunks = len(inputs) // 2
attn_split, lse_split = inputs[:num_chunks], inputs[num_chunks:]
triton_splitk.merge_attentions_varargs(attn_out, lse_out, attn_split, lse_split)
ctx.save_for_backward(
attn_out,
lse_out,
*inputs,
)
return attn_out, lse_out
@staticmethod
# type: ignore
def backward(
ctx, grad_attn: torch.Tensor, grad_lse: torch.Tensor
) -> Tuple[Optional[torch.Tensor], ...]:
out, lse, *inputs = ctx.saved_tensors
num_chunks = len(inputs) // 2
attn_split, lse_split = inputs[:num_chunks], inputs[num_chunks:]
dattn, dlse = triton_splitk.merge_attentions_varargs_backward(
attn_split,
lse_split,
out,
lse,
grad_attn,
grad_lse,
)
ret = [None, None] + dattn + dlse
return tuple(ret)
ALL_FW_OPS: List[Type[AttentionFwOpBase]] = [
cutlass.FwOp if torch.version.cuda else ck.FwOp,
flash.FwOp,
flash3.FwOp,
triton_splitk.FwOp,
]
ALL_BW_OPS: List[Type[AttentionBwOpBase]] = [
cutlass.BwOp if torch.version.cuda else ck.BwOp,
flash.BwOp,
flash3.BwOp,
]
__all__ = [
"AttentionBias",
"AttentionOp",
"AttentionOpBase",
"LowerTriangularMask",
"MemoryEfficientAttentionCutlassFwdFlashBwOp",
"MemoryEfficientAttentionCutlassOp",
"MemoryEfficientAttentionFlashAttentionOp",
"memory_efficient_attention",
"MemoryEfficientAttentionCkOp",
"MemoryEfficientAttentionCkDecoderOp",
"ALL_FW_OPS",
"ALL_BW_OPS",
"attn_bias",
"_get_use_fa3",
"_set_use_fa3",
"BlockDiagonalMask",
]