# Source code for xformers.ops.fmha.small_k
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union
import torch
from ..common import get_xformers_operator, register_operator
from .attn_bias import AttentionBias
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Context,
Gradients,
Inputs,
bmk2bmhk,
)
def _bmhk2bmk_contiguous(tensor) -> torch.Tensor:
return (
tensor.permute((0, 2, 1, 3))
.contiguous()
.view([tensor.shape[0] * tensor.shape[2], tensor.shape[1], tensor.shape[3]])
.contiguous()
)
def _get_tensor_bias_bmk(
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> Optional[torch.Tensor]:
    """Convert ``attn_bias`` into the form passed to the small-K kernels.

    Args:
        attn_bias: ``None`` or a bias tensor. Other ``AttentionBias`` types
            are not handled here.

    Returns:
        ``None`` when there is no bias. Otherwise the tensor itself. A 4-D
        bias has its two leading dimensions merged (``[B, H, ...]`` ->
        ``[B * H, ...]``, presumably batch and heads) to line up with the
        folded BMK inputs.

    Raises:
        NotImplementedError: if ``attn_bias`` is neither ``None`` nor a tensor.
    """
    if attn_bias is None:
        return None
    if not isinstance(attn_bias, torch.Tensor):
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently drop an unsupported bias.
        raise NotImplementedError(
            f"Unsupported attn_bias type for small_k: {type(attn_bias)}"
        )
    # BMHK -> BMK: merge the two leading dims of a 4-D bias.
    if attn_bias.ndim == 4:
        attn_bias = attn_bias.reshape([-1, *attn_bias.shape[2:]])
    return attn_bias
@register_operator
class FwOp(AttentionFwOpBase):
    """An operator optimized for very small values of K (``K <= 32``) \
        and f32 pre-Ampere as it does not use TensorCores.
    Only supports contiguous inputs in BMK format, so an extra reshape \
        or contiguous call might be done.

    :Deprecated:

        This operator is deprecated and should not be used in new code
    """

    OPERATOR = get_xformers_operator("efficient_attention_forward_small_k")
    SUPPORTED_DEVICES = {"cuda"}
    SUPPORTED_DTYPES = {torch.float}
    SUPPORTED_MAX_K: float = 32
    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (type(None), torch.Tensor)
    SUPPORTS_DROPOUT = True
    SUPPORTS_CUSTOM_SCALE = False
    NAME = "smallkF"

    BACKWARD_ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 4e-3,
    }
    # as this kernel is a bit slow, this should make tests run faster
    _TEST_BATCH_SIZES = [1, 3]
    _TEST_K = [2, 3, 8, 16, 32]

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons ``d`` cannot run on this operator.

        An empty list means the inputs are supported. On top of the base
        class checks, this rejects:

        * a tensor bias whose stride along dim 1 is non-zero;
        * an embedding size ``K`` that cannot be split into at most
          ``buffer_size`` (8) chunks of 1, 2 or 4 elements.
        """
        reasons = super(FwOp, cls).not_supported_reasons(d)
        if isinstance(d.attn_bias, torch.Tensor) and d.attn_bias.stride(1) != 0:
            # NOTE(review): presumably the kernel expects the bias to be
            # broadcast (expanded) along dim 1 — confirm against the kernel.
            reasons.append("bias with non-zero stride not supported")
        buffer_size = 8
        k = d.query.shape[-1]
        if not any(
            k % pack == 0 and k // pack <= buffer_size for pack in (1, 2, 4)
        ):
            reasons.append(f"unsupported embed per head: {k}")
        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the forward pass.

        The BMHK query/key/value are folded into ``[B * H, M, K]`` for the
        kernel, and the kernel output is converted back with ``bmk2bmhk``.

        Returns:
            ``(out, ctx)``. ``ctx`` is ``None`` unless ``needs_gradient``.
            Otherwise it holds ``out`` and the logsumexp reshaped to
            ``[B, H, M]``. With dropout (``p != 0``) it also holds the RNG
            seed/offset and pins ``BwOp`` as the backward operator.

        Raises:
            NotImplementedError: if a custom ``scale`` is requested.
        """
        if inp.scale is not None:
            raise NotImplementedError("Unsupport custom scale")
        num_heads = inp.query.shape[2]
        query = _bmhk2bmk_contiguous(inp.query)
        key = _bmhk2bmk_contiguous(inp.key)
        value = _bmhk2bmk_contiguous(inp.value)
        out, lse, rng_seed, rng_offset = cls.OPERATOR(
            query=query,
            key=key,
            value=value,
            compute_logsumexp=needs_gradient,
            attn_bias=_get_tensor_bias_bmk(inp.attn_bias),
            p=inp.p,
        )
        out = bmk2bmhk(out, num_heads)
        if not needs_gradient:
            # The logsumexp was not requested; return before depending on
            # its shape.
            return out, None
        # LSE: (BH)M -> BHM
        lse = lse.reshape([lse.shape[0] // num_heads, num_heads, lse.shape[1]])
        ctx = Context(out=out, lse=lse)
        if inp.p != 0.0:
            # The dropout mask is reproduced in backward from this seed/offset.
            # Presumably only this file's BwOp can replay it, hence the pin.
            ctx.op_bw = BwOp
            ctx.rng_state = torch.tensor(
                [rng_seed, rng_offset], dtype=torch.int64, device="cpu"
            )
        return out, ctx
@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    OPERATOR = get_xformers_operator("efficient_attention_backward_small_k")
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    # there is some extra precision loss in the CPU implementation due to an
    # extra accumulation step in grad_q, which is not present in the CUDA
    # implementation
    # NOTE(review): SUPPORTED_DEVICES is CUDA-only, so the CPU remark above
    # may be stale — confirm.
    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 4e-3,
    }
    NAME = "smallkB"

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons ``d`` cannot run on this operator.

        An empty list means the inputs are supported. On top of the base
        class checks, this rejects:

        * a tensor bias whose stride along dim 1 is non-zero;
        * an embedding size ``K`` that cannot be split into at most
          ``buffer_size`` (8) chunks of 1, 2 or 4 elements.
        """
        reasons = super(BwOp, cls).not_supported_reasons(d)
        if isinstance(d.attn_bias, torch.Tensor) and d.attn_bias.stride(1) != 0:
            # NOTE(review): presumably the kernel expects the bias to be
            # broadcast (expanded) along dim 1 — confirm against the kernel.
            reasons.append("bias with non-zero stride not supported")
        buffer_size = 8
        k = d.query.shape[-1]
        if not any(
            k % pack == 0 and k // pack <= buffer_size for pack in (1, 2, 4)
        ):
            reasons.append(f"unsupported embed per head: {k}")
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Compute gradients for query, key and value.

        All BMHK tensors (``grad``, the inputs and the saved forward output)
        are folded to ``[B * H, M, K]`` for the kernel. The resulting
        gradients are converted back with ``bmk2bmhk``.

        Raises:
            NotImplementedError: if dropout was used (``p != 0``) but
                ``ctx.rng_state`` is not a CPU int64 tensor of shape ``(2,)``.
        """
        num_heads = grad.shape[2]
        grad_bmk = _bmhk2bmk_contiguous(grad)
        query = _bmhk2bmk_contiguous(inp.query)
        key = _bmhk2bmk_contiguous(inp.key)
        value = _bmhk2bmk_contiguous(inp.value)
        out = _bmhk2bmk_contiguous(ctx.out)

        rng_seed = rng_offset = 0
        if inp.p != 0.0:
            # The forward pass stores [seed, offset] so the dropout mask can
            # be reproduced here.
            if (
                ctx.rng_state is None
                or ctx.rng_state.dtype != torch.int64
                or ctx.rng_state.device.type != "cpu"
                or ctx.rng_state.shape != (2,)
            ):
                raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
            rng_seed, rng_offset = ctx.rng_state.tolist()

        grad_q, grad_k, grad_v = cls.OPERATOR(
            grad_bmk,
            query,
            key,
            value,
            # LSE: BHM -> (BH)M
            ctx.lse.flatten(end_dim=-2),
            out,
            _get_tensor_bias_bmk(inp.attn_bias),
            inp.p,
            rng_seed,
            rng_offset,
        )
        return Gradients(
            dq=bmk2bmhk(grad_q, num_heads),
            dk=bmk2bmhk(grad_k, num_heads),
            dv=bmk2bmhk(grad_v, num_heads),
        )