Source code for xformers.components.attention.local
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    maybe_sparsify,
    register_attention,
    sparsify,
)
from xformers.components.attention.attention_patterns import (
    causal_1d_pattern,
    local_1d_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention


@dataclass
class LocalAttentionConfig(AttentionConfig):
    causal: Optional[bool] = None
    window_size: Optional[int] = None
    force_sparsity: Optional[bool] = None
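
# Usage sketch (added note, not part of the original module): this config is
# typically consumed through the xformers attention factory rather than
# instantiated by hand; the field values below are illustrative assumptions.
#
#   from xformers.components.attention import build_attention
#
#   attention = build_attention(
#       {"name": "local", "dropout": 0.1, "causal": True, "window_size": 5}
#   )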

@register_attention("local", LocalAttentionConfig)
class LocalAttention(Attention):
    def __init__(
        self,
        dropout: float = 0.0,
        causal: bool = False,
        window_size: int = 5,
        force_sparsity: bool = False,
        *args,
        **kwargs,
    ):
r"""
An implementation of a sliding window attention, as proposed in RoutingTransformer_, LongFormer_ or BigBird_
Args:
dropout (float): the probability of an output to be randomly dropped at training time
causal (bool): apply a causal mask, in that the attention cannot be applied to the future
window_size (int): the overall window size for local attention.
Odd number is expected if the mask is not causal, as the window size will be evenly
distributed on both sides of each query
.. _RoutingTransformer: https://arxiv.org/pdf/2003.05997.pdf
.. _BigBird: https://arxiv.org/pdf/2007.14062.pdf
.. _Longformer: https://arxiv.org/pdf/2004.05150.pdf
"""
        super().__init__()

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.causal = causal
        self.force_sparsity = force_sparsity

        if not self.causal:
            assert (
                window_size % 2 == 1
            ), "The window size is assumed to be odd (counts self-attention + 2 wings)"

        self.window_size = window_size
        self.attention_mask: Optional[torch.Tensor] = None
        self.requires_same_k_q_dimensions = True

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def _get_local_mask(self, shape: torch.Size) -> torch.Tensor:
        # In the causal case the window only extends into the past, so a wider
        # symmetric pattern is built first, then intersected with a causal mask
        window_size = self.window_size * 2 + 1 if self.causal else self.window_size
        mask = local_1d_pattern(shape[1], window_size)

        if self.causal:
            mask &= causal_1d_pattern(shape[1])

        mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask)

        return mask
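
    # Illustration (added note, not part of the original source): for a
    # non-causal mask with window_size=3 over a sequence of length 5,
    # local_1d_pattern yields a banded boolean matrix, True where a query
    # (row) may attend to a key (column):
    #
    #   [[1, 1, 0, 0, 0],
    #    [1, 1, 1, 0, 0],
    #    [0, 1, 1, 1, 0],
    #    [0, 0, 1, 1, 1],
    #    [0, 0, 0, 1, 1]]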

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        *args,
        **kwargs,
    ):
        # Local window attention masking: build the mask lazily and cache it,
        # rebuilding only when the sequence length changes
        if self.attention_mask is None or self.attention_mask.shape[1] != q.shape[1]:
            self.attention_mask = self._get_local_mask(q.shape).to(q.device)

        # Take into account the optional user mask
        if att_mask is None:
            mask = self.attention_mask
        else:
            if isinstance(att_mask, AttentionMask):
                # Needed because the & op is not defined between SparseCS and AttentionMask
                att_mask = att_mask.to_bool()
            mask = self.attention_mask & att_mask

        return scaled_dot_product_attention(
            q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop
        )
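

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original module):
# a minimal smoke test, assuming q/k/v of shape (batch, seq_len, dim) as
# consumed by scaled_dot_product_attention; shapes and values are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, seq_len, dim = 2, 16, 32

    attention = LocalAttention(dropout=0.0, causal=False, window_size=5)
    q = k = v = torch.randn(batch, seq_len, dim)

    out = attention(q, k, v)  # each query sees a 5-token neighbourhood
    assert out.shape == (batch, seq_len, dim)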