
Source code for xformers.components.attention.scaled_dot_product

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    register_attention,
)
from xformers.components.attention.core import scaled_dot_product_attention

logger = logging.getLogger("xformers")


@dataclass
class ScaledDotProductConfig(AttentionConfig):
    causal: Optional[bool]
    seq_len: Optional[int]
    to_seq_len: Optional[int]
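
Because the class below registers itself under the name "scaled_dot_product" together with this config, an instance can also be built from a plain dictionary whose keys mirror the config fields. A minimal sketch, assuming the build_attention factory exposed by xformers.components.attention (not defined in this file):

    from xformers.components.attention import build_attention

    # Keys mirror the ScaledDotProductConfig fields above; "name" selects the
    # registered attention mechanism.
    attention = build_attention(
        {
            "name": "scaled_dot_product",
            "dropout": 0.1,
            "causal": True,
            "seq_len": 512,
        }
    )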


@register_attention("scaled_dot_product", ScaledDotProductConfig)
class ScaledDotProduct(Attention):
    r"""
    Implementing the Scaled Dot-Product attention proposed in
    `Attention is all you need`_, Vaswani et al.

    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
    """

    mask: Optional[AttentionMask]

    def __init__(
        self,
        dropout: float = 0.0,
        causal: bool = False,
        seq_len: Optional[int] = None,
        to_seq_len: Optional[int] = None,
        *args,
        **kwargs,
    ):
        super().__init__()

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.causal = causal
        self.seq_len = seq_len

        if causal and seq_len is not None:
            self.mask = AttentionMask.make_causal(seq_len, to_seq_len)
        else:
            self.mask = None

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        att_mask    A 2D or 3D mask which ignores attention at certain positions.

                    - If the mask is boolean, a value of True will keep the value,
                      while a value of False will mask the value.

                      Key padding masks (dimension: batch x sequence length) and
                      attention masks (dimension: sequence length x sequence length
                      OR batch x sequence length x sequence length) can be combined
                      and passed in here. Method maybe_merge_masks provided in the
                      utils can be used for that merging.

                    - If the mask has the float type, then an additive mask is
                      expected (masked values are -inf)
        """

        # Convenience, create an attention mask if a tensor was passed
        if att_mask is not None and isinstance(att_mask, torch.Tensor):
            # By default we don't know of the causality, and a check would be expensive
            att_mask = (
                AttentionMask.from_bool(att_mask)
                if att_mask.dtype == torch.bool
                else AttentionMask(att_mask, is_causal=False)
            )

        # Handle a possibly deferred causal mask handling
        mask = self.mask
        if self.causal and self.mask is None:
            mask = AttentionMask.make_causal(
                seq_len=q.shape[-2],
                to_seq_len=q.shape[-2],
                device=q.device,
                dtype=q.dtype,
            )

        # Merge the optional causal mask and the user-provided mask
        if mask is not None:
            mask = mask.to(dtype=q.dtype, device=q.device)

            att_mask = att_mask + mask if att_mask is not None else mask

        # Try to handle a case where the sequence is smaller than the mask
        if (
            att_mask is not None
            and q.shape[-2] == k.shape[-2]
            and q.shape[-2] < att_mask.shape[1]
        ):
            if isinstance(att_mask, AttentionMask):
                att_mask = att_mask.make_crop(seq_len=q.shape[-2])
            else:
                logger.error(
                    "Mismatching sparse attention mask and sequence length."
                    + " Please pad the inputs or adjust the attention mask"
                )
                raise NotImplementedError

        # Attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S)
        y = scaled_dot_product_attention(
            q=q, k=k, v=v, att_mask=att_mask, dropout=self.attn_drop
        )
        return y
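
As a usage illustration (a minimal sketch, not part of the module): the q, k and v tensors follow the (batch * heads, sequence length, head dimension) layout used in the shape comment above, and a boolean att_mask keeps the positions set to True, as described in the forward docstring.

    import torch

    from xformers.components.attention import ScaledDotProduct

    attention = ScaledDotProduct(dropout=0.0, causal=False)

    B, S, D = 2, 16, 32  # batch * heads, sequence length, head dimension
    q, k, v = (torch.randn(B, S, D) for _ in range(3))

    # Boolean mask: True keeps a position, False masks it out (here: keep everything)
    att_mask = torch.ones(S, S, dtype=torch.bool)

    y = attention(q=q, k=k, v=v, att_mask=att_mask)
    assert y.shape == (B, S, D)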