# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import (
scaled_dot_product_attention,
scaled_query_key_softmax,
)
from xformers.components.attention.utils import (
bool_mask_to_additive,
iterative_pinv,
reshape_key_padding_mask,
)
logger = logging.getLogger("xformers")
@dataclass
class NystromSelfAttentionConfig(AttentionConfig):
    """Configuration for :class:`NystromAttention`.

    Fields:
        num_heads: Number of attention heads.
        num_landmarks: Number of landmarks used for the softmax approximation.
            64 is often sufficient in practice, per https://arxiv.org/pdf/2102.03902.pdf.
        landmark_pooling: Module used to compute the landmarks.
            Defaults to an adaptive average pool when unset.
        causal: If True, apply a causal mask so attention cannot look at the future.
        pinverse_original_init: True to use the original initialization from
            (Razavi et al. 2014) when computing the Moore-Penrose pseudo-inverse;
            False to use exact coefficient computation (faster convergence).
        inv_iterations: Iteration count for the iterative pseudo-inverse.
        v_skip_connection: Module taking V as input, added as a skip connection
            to the softmax approximation (helps training, per the paper).
        conv_kernel_size: Kernel size for the optional default depth-wise
            convolution skip connection, used only when v_skip_connection is None.
            If both are None, no skip connection is added.
        use_razavi_pinverse: True to approximate the Moore-Penrose inverse with
            the iterative method from (Razavi et al. 2014), False to use the
            standard torch inverse.
    """

    num_heads: int
    num_landmarks: Optional[int]
    landmark_pooling: Optional[nn.Module]
    causal: Optional[bool]
    pinverse_original_init: Optional[bool]
    inv_iterations: Optional[int]
    v_skip_connection: Optional[nn.Module]
    conv_kernel_size: Optional[int]
    use_razavi_pinverse: Optional[bool]
class AvgPool(nn.Module):
    """Pool a (batch, seq, dim) tensor down to ``n`` landmarks by averaging
    contiguous, non-overlapping segments along the sequence dimension."""

    def __init__(self, n: int):
        super().__init__()
        self.n = n

    def forward(self, x: torch.Tensor):
        # Average independently over every segment of the sequence dimension.
        seq_len, head_dim = x.shape[1], x.shape[2]
        base_len, remainder = divmod(seq_len, self.n)
        assert base_len > 0, "num_landmarks should be smaller than the sequence length"

        if remainder == 0:
            # Sequence length divides evenly: every landmark covers base_len tokens.
            return x.reshape(-1, self.n, base_len, head_dim).mean(dim=-2)

        # Uneven split: the first (n - remainder) landmarks each cover base_len
        # tokens, the remaining landmarks each cover base_len + 1 tokens.
        short_count = self.n - remainder
        split = short_count * base_len
        short_part = (
            x[:, :split, :]
            .reshape(-1, short_count, base_len, head_dim)
            .mean(dim=-2)
        )
        long_part = (
            x[:, split:, :]
            .reshape(-1, remainder, base_len + 1, head_dim)
            .mean(dim=-2)
        )
        return torch.cat((short_part, long_part), dim=-2)
@register_attention("nystrom", NystromSelfAttentionConfig)
class NystromAttention(Attention):
    # TODO: update defaults for use_razavi_pinverse and inv_iterations

    def __init__(
        self,
        dropout: float,
        num_heads: int,
        num_landmarks: int = 64,
        landmark_pooling: Optional[nn.Module] = None,
        causal: bool = False,
        use_razavi_pinverse: bool = True,
        pinverse_original_init: bool = False,
        inv_iterations: int = 6,  # recommended default in paper was 6.
        v_skip_connection: Optional[nn.Module] = None,
        conv_kernel_size: Optional[int] = None,
        *args,
        **kwargs,
    ):
        """
        Nystrom attention mechanism, from Nystromformer_.
        ::

            "A Nystrom-based Algorithm for Approximating Self-Attention."
            Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V. (2021)

            Reference codebase: https://github.com/mlpen/Nystromformer

        .. _Nystromformer: https://arxiv.org/pdf/2102.03902.pdf
        """
        super().__init__()
        # merged key padding mask and attention mask is not accepted
        self.requires_separate_masks = True
        self.num_landmarks = num_landmarks
        # TODO: should be able to not have to pass in num_heads
        self.num_heads = num_heads
        self.use_razavi_pinverse = use_razavi_pinverse
        self.pinverse_original_init = pinverse_original_init
        self.inv_iterations = inv_iterations
        self.attn_drop = nn.Dropout(dropout)
        self.skip_connection = v_skip_connection
        self.causal = causal

        if self.skip_connection is None and conv_kernel_size is not None:
            # Default skip connection: a per-head convolution over the sequence
            # axis (groups == in_channels == num_heads, no cross-head mixing).
            self.skip_connection = nn.Conv2d(
                in_channels=self.num_heads,
                out_channels=self.num_heads,
                kernel_size=(conv_kernel_size, 1),
                padding=(conv_kernel_size // 2, 0),
                bias=False,
                groups=self.num_heads,
            )

        if landmark_pooling is not None:
            self.landmark_pooling = landmark_pooling
        else:
            self.landmark_pooling = AvgPool(n=self.num_landmarks)

        # Optional lower triangular masks for causal attention,
        # built lazily in forward() and cached across calls.
        self.causal_mask_1: Optional[torch.Tensor] = None
        self.causal_mask_2: Optional[torch.Tensor] = None
        self.causal_mask_3: Optional[torch.Tensor] = None

        # This attention does not support attention masks
        self.supports_attention_mask = False
        self.supports_key_padding_mask = True

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ):
        r"""
        key_padding_mask    Only a key padding mask is accepted here. The size must be (batch size, sequence length) or
                            (batch size * num_heads, 1, sequence length). If dimensions are not correct, the mask will
                            be ignored. An additive mask is expected, meaning float values using "-inf" to mask values
        """
        batched_dim = k.size(0)
        seq_len = k.size(-2)
        tt = {"dtype": q.dtype, "device": q.device}

        if key_padding_mask is not None:
            if key_padding_mask.dtype == torch.bool:
                logger.warning(
                    "Bool mask found, but an additive mask is expected. Converting but this is slow"
                )
                key_padding_mask = bool_mask_to_additive(key_padding_mask)

            if key_padding_mask.ndim == 2:
                key_padding_mask = reshape_key_padding_mask(
                    key_padding_mask, batched_dim
                )

            zeros = torch.zeros_like(key_padding_mask)
            ones = torch.ones_like(key_padding_mask)
            is_masked = torch.isinf(-key_padding_mask)

            # _mask takes 1 if the token is not padded, otherwise 0.
            _mask = torch.where(is_masked, zeros, ones)
            _mask = _mask.transpose(2, 1)
            assert _mask.shape == (batched_dim, q.shape[1], 1)
            # Mask q and k before pooling so padded tokens do not leak into the landmarks
            # https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py#L31
            q = q * _mask
            k = k * _mask

            assert key_padding_mask.size() == (batched_dim, 1, seq_len), (
                f"key_padding_mask has invalid dimensions {key_padding_mask.size()}."
                f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})."
            )

        if self.num_landmarks >= seq_len:
            # More landmarks than tokens: the approximation degenerates,
            # fall back to exact attention.
            mask: Optional[torch.Tensor] = None
            if self.causal:
                mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt)
            if key_padding_mask is not None:
                mask = key_padding_mask if mask is None else mask + key_padding_mask
            x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask)

        else:
            q_landmarks = self.landmark_pooling(q)
            k_landmarks = self.landmark_pooling(k)

            if self.causal and (
                self.causal_mask_1 is None
                or (batched_dim, seq_len, self.num_landmarks)
                != self.causal_mask_1.size()
            ):
                self.causal_mask_1 = self._triu_mask(
                    batched_dim, seq_len, self.num_landmarks, **tt
                )
                self.causal_mask_2 = self._triu_mask(
                    batched_dim, self.num_landmarks, self.num_landmarks, **tt
                )
                self.causal_mask_3 = self._triu_mask(
                    batched_dim, self.num_landmarks, seq_len, **tt
                )

            mask_3: Optional[torch.Tensor] = self.causal_mask_3
            if key_padding_mask is not None:
                mask_3 = (
                    key_padding_mask if mask_3 is None else mask_3 + key_padding_mask
                )

            # NOTE(review): causal_mask_1 / causal_mask_2 are cached above but
            # kernel_1 and kernel_2 are computed with att_mask=None — only
            # kernel_3 is masked. Confirm against upstream before changing.
            kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=None)
            kernel_2 = scaled_query_key_softmax(
                q=q_landmarks, k=k_landmarks, att_mask=None
            )
            kernel_3 = scaled_dot_product_attention(
                q=q_landmarks, k=k, v=v, att_mask=mask_3
            )

            kernel_2_inv = (
                iterative_pinv(
                    kernel_2, self.inv_iterations, self.pinverse_original_init
                )
                if self.use_razavi_pinverse
                else torch.linalg.pinv(kernel_2)
            )

            # softmax(QK^T) ~= kernel_1 @ pinv(kernel_2) @ (kernel_3 applied to V)
            x = torch.matmul(
                torch.matmul(
                    kernel_1,
                    kernel_2_inv,
                ),
                kernel_3,
            )

        if self.skip_connection:
            # Assumption here is that v is 3D.
            v_conv = self.skip_connection(
                v.reshape(-1, self.num_heads, v.size(-2), v.size(-1))
            )
            x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1))
        x = self.attn_drop(x)
        return x

    def _triu_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor:
        """Build a (dim_1, dim_2, dim_3) additive causal mask: -inf strictly above
        the diagonal, 0 elsewhere."""
        device = kwargs["device"]
        dtype = kwargs["dtype"]

        return torch.triu(
            torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf"),
            diagonal=1,
        ).expand(
            dim_1, -1, -1
        )  # micro optim, save memory on the batch dimension