
Source code for xformers.ops.fmha.ck_decoder

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Iterable, List, Optional, Set, Tuple

import torch

from ..common import get_xformers_operator, register_operator
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
from .common import AttentionFwOpBase, Context, Inputs

@register_operator
class FwOp(AttentionFwOpBase):
    """
    An operator optimized for K=256 (so the contiguous dim fits into registers).
    Tested to work on MI250x.

    Forward-only: :meth:`apply` raises ``NotImplementedError`` when a gradient
    is requested. The accepted attention biases are ``None`` and
    ``BlockDiagonalCausalWithOffsetPaddedKeysMask``.
    """

    OPERATOR = get_xformers_operator("efficient_attention_forward_decoder_ck")
    SUPPORTED_DEVICES: Set[str] = {"cuda"}
    SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16, torch.float}
    SUPPORTED_MAX_K: int = 256
    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
        type(None),
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
    )
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_BMGHK = True
    NAME = "ck_decoderF"

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons why this operator cannot process ``d``.

        The list starts with whatever the parent class reports. When the bias
        is a ``BlockDiagonalCausalWithOffsetPaddedKeysMask``, additional checks
        on the batch dim, head dim, key/value strides, query layout and key
        padding are appended. For any other bias only the parent's reasons are
        returned.

        This method only *reports* problems: degenerate shapes (zero padding,
        keys shorter than the padding, fewer query rows than batch elements,
        or an oversized head dim) produce a reason instead of raising
        ``ZeroDivisionError`` / ``ValueError``.

        Args:
            d: The attention inputs to validate.

        Returns:
            Human-readable reasons; an empty list means no objection was found.
        """
        reasons = super(FwOp, cls).not_supported_reasons(d)

        attn_bias = d.attn_bias
        if not isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
            return reasons

        if d.query.shape[0] != 1:
            reasons.append(
                f"One formal batch element expected; got {d.query.shape[0]}"
            )

        head_dim = d.query.shape[-1]
        if head_dim > cls.SUPPORTED_MAX_K:
            reasons.append(
                f"Got head_dim={head_dim}; only head_dim<={cls.SUPPORTED_MAX_K} is supported for now."
            )

        threads_per_warp = 64  # TODO: ideally query the platform here
        # No `break` on purpose: the last (smallest) vector size that still
        # covers head_dim wins. It stays 0 when even vec_size=4 is too small.
        required_alignment = 0
        for vec_size in (4, 2, 1):
            if head_dim <= vec_size * threads_per_warp:
                required_alignment = vec_size

        if not required_alignment:
            reasons.append(f"Got head_dim={head_dim} which is too large")
        elif head_dim % required_alignment != 0:
            # `elif`: taking the modulo by a zero alignment would raise.
            reasons.append(
                f"Got head_dim={head_dim}; it needs to be divisible by {required_alignment}"
            )

        if d.key.stride(-1) != 1:
            reasons.append("expect keys to have last dim contiguous")

        if d.value.stride(-1) != 1:
            reasons.append("expect values to have last dim contiguous")

        q_starts = attn_bias.q_seqinfo.seqstart_py
        padding = attn_bias.k_seqinfo.padding
        # Number of padded key blocks; 0 when the padding is unusable or the
        # keys do not even fill one block.
        bsz = d.key.shape[1] // padding if padding > 0 else 0

        if bsz == 0:
            reasons.append(
                f"Got key length {d.key.shape[1]} with padding {padding}; "
                "at least one full padded batch element is expected"
            )
        else:
            num_queries = d.query.shape[1] // bsz
            # NOTE(review): the stop value is `1 + bsz`, not
            # `1 + bsz * num_queries`, so for num_queries > 1 the expected
            # list has fewer than bsz + 1 entries — confirm this is intended.
            # num_queries == 0 is rejected up front because range() refuses a
            # zero step.
            if num_queries == 0 or q_starts != list(
                range(0, 1 + bsz, num_queries)
            ):
                reasons.append("expect to have same num_queries in each batch")
            if bsz != len(q_starts) - 1:
                reasons.append("empty lanes not supported yet")

        if padding > 8192:
            reasons.append("key padding exceeds 8192")

        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the forward decoder kernel on ``inp``.

        Args:
            inp: Attention inputs; q/k/v are fetched via
                ``inp.get_qkv_in_bmghk()``.
            needs_gradient: Must be falsy, as no backward pass exists.

        Returns:
            ``(out, None)`` where ``out`` is the result of ``cls.OPERATOR``;
            no context is produced.

        Raises:
            NotImplementedError: If ``needs_gradient`` is truthy.
        """
        if needs_gradient:
            raise NotImplementedError("backward pass is not supported")

        attn_bias = inp.attn_bias
        q, k, v = inp.get_qkv_in_bmghk()

        if attn_bias is not None:
            assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
            padding = attn_bias.k_seqinfo.padding
            seq_positions_gpu = attn_bias.k_seqinfo.seqlen

            # key: (1, B * padding, G, 1 if multiquery else Hkv, D)
            # value: like key
            # query: (1, B * q_seqlen, G, Hq, D)
            # A zero stride along the head dim is taken as the multiquery
            # marker; in that case only the first head slice is kept.
            multiquery = k.stride(3) == 0
            if multiquery:
                key = k[0, :, :, :1].unflatten(0, (-1, padding))
                value = v[0, :, :, :1].unflatten(0, (-1, padding))
            else:
                key = k[0].unflatten(0, (-1, padding))
                value = v[0].unflatten(0, (-1, padding))
            # Split the flattened query rows into the same number of batch
            # elements as the keys.
            query = q[0].unflatten(0, (key.shape[0], -1))
        else:
            seq_positions_gpu = None
            # key: (B, padding, G, 1 if multiquery else Hkv, D)
            # value: like key
            # query: (B, q_seqlen, G, Hq, D)
            key = k
            query = q
            value = v

        if inp.scale is not None:
            qk_scale = inp.scale
        else:
            # Default scale: 1 / sqrt(head_dim), evaluated in float32.
            qk_scale = torch.rsqrt(
                torch.tensor(key.shape[-1], dtype=torch.float32)
            ).item()

        out = cls.OPERATOR(
            query=query,
            key=key,
            value=value,
            seq_positions=seq_positions_gpu,
            scale=qk_scale,
        )
        return out, None