Source code for xformers.components.attention.base
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABCMeta, abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar

import torch
import torch.nn as nn

from xformers._deprecation_warning import deprecated_function
from xformers.components.attention import AttentionMask


@dataclass
class AttentionConfig:
    """Parameters required for all Attentions.
    Can accept and store extra parameters.
    """

    name: str  # the registered name for this attention mechanism
    dropout: float  # dropout probability
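
# Example (illustrative values): a config such as AttentionConfig(name="scaled_dot_product", dropout=0.1)
# maps one-to-one onto the keyword arguments of the matching attention constructor;
# Attention.from_config below drops any fields left as None so that constructor defaults apply.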

Self = TypeVar("Self", bound="Attention")


# Define the common interface; every attention block needs to derive from it
class Attention(nn.Module, metaclass=ABCMeta):
    r"""The base Attention mechanism, which is typically a sub-part of multi-head attention"""

    _causal_mask: Optional[AttentionMask] = None

    @abstractmethod
    def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
        super().__init__()
        deprecated_function(self)

        # Requires the inputs to be projected
        self.requires_input_projection = True

        # Whether the head dimension needs to be present (if not, it can be folded into the batch dimension)
        self.requires_head_dimension = False

        # The key padding mask and the attention mask must be passed in as separate arguments,
        # instead of a single merged attention mask
        self.requires_separate_masks = False

        # Requires that K and Q have the same sequence length
        self.requires_same_k_q_dimensions = False

        # Whether the attention owns the single-head/multi-head mechanism,
        # so that the MHA wrapper should skip it
        self.requires_skip_multi_head = False

        # This attention requires the context length to be a perfect square, often due to 2D pooling
        self.requires_squared_context = False

        # Whether this attention mechanism supports attention masks
        self.supports_attention_mask = True

        # Whether this attention mechanism supports key padding masks
        self.supports_key_padding_mask = False

    @classmethod
    def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
        # Generate the class inputs from the config
        fields = asdict(config)

        # Skip all Nones so that default values are used
        fields = {k: v for k, v in fields.items() if v is not None}

        return cls(**fields)

    @abstractmethod
    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _maybe_pad_sequence(x: torch.Tensor, mask: torch.Tensor):
        """
        If the sequence is shorter than the mask, return x zero-padded
        along the sequence dimension so that it matches the mask length
        """
        if x.shape[-2] != mask.shape[-1]:
            assert x.shape[-2] < mask.shape[-1], (
                "Sequence is bigger than the provided mask, cannot infer what to do with it."
                " Please update your attention mask"
            )
            pad_size = (0, 0, 0, mask.shape[-1] - x.shape[-2], 0, 0)
            return torch.nn.functional.pad(x, pad_size, mode="constant", value=0.0)

        return x
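
Usage sketch: a minimal subclass only needs to override __init__ and forward, after which it can be built through Attention.from_config and called like any nn.Module. The class MyDotProductAttention and the config name "my_dot_product" below are hypothetical, masks are ignored for brevity, and the snippet assumes the xformers package is importable.

import torch

from xformers.components.attention.base import Attention, AttentionConfig


class MyDotProductAttention(Attention):
    """Hypothetical subclass: plain softmax(Q K^T / sqrt(d)) V, without any masking."""

    def __init__(self, dropout: float = 0.0, *args, **kwargs):
        super().__init__()
        self.attn_drop = torch.nn.Dropout(dropout)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
        return self.attn_drop(torch.softmax(scores, dim=-1)) @ v


# Build the attention from a config; extra fields such as "name" are absorbed by **kwargs
config = AttentionConfig(name="my_dot_product", dropout=0.1)
attention = MyDotProductAttention.from_config(config)

q = k = v = torch.randn(2, 16, 64)  # (batch, sequence length, head dimension)
out = attention(q, k, v)
assert out.shape == (2, 16, 64)

In xformers itself the name field is used to look the attention class up in the attention registry; in this standalone sketch it is simply swallowed by **kwargs.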