
Source code for xformers.components.attention.attention_mask

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional, Type, TypeVar

import torch

Self = TypeVar("Self", bound="AttentionMask")


class AttentionMask:
    """
    Holds an attention mask, along with a couple of helpers and attributes.

    .. note: this is an additive mask, meaning that coefficients which should be
        computed hold the '0.' value, and coefficients which should be skipped hold
        the '-inf' value. Any other value is possible, for instance to bias the
        attention computation.

    .. note: the attention mask dimensions are expected to be
        `[batch, to_sequence, from_sequence]`, `[to_sequence, from_sequence]`,
        or anything broadcastable in between
    """

    def __init__(self, additive_mask: torch.Tensor, is_causal: bool = False):
        assert additive_mask.is_floating_point(), additive_mask.dtype
        assert not additive_mask.requires_grad

        if additive_mask.ndim == 2:
            additive_mask = additive_mask.unsqueeze(0)

        self.values = additive_mask
        self.is_causal = is_causal
        self.seq_len = additive_mask.shape[1]
        self.to_seq_len = additive_mask.shape[0]
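    # Illustrative example (not part of the original source, added for clarity):
    # for a 2x2 pattern where the first query may not attend to the second key,
    # the expected additive mask is
    #   [[0., -inf],
    #    [0.,   0.]]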
    def to_bool(self) -> torch.Tensor:
        """
        .. warning: we assume here that True implies that the value should be computed
        """
        return self.values != float("-inf")
    @classmethod
    def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
        """
        Create an AttentionMask given a boolean pattern.

        .. warning: we assume here that True implies that the value should be computed
        """
        assert x.dtype == torch.bool

        additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
        additive_mask.masked_fill_(x, 0.0)
        additive_mask.masked_fill_(~x, float("-inf"))

        return cls(additive_mask)
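    # Illustrative example (not part of the original source): from_bool maps
    #   tensor([[True, False]])  ->  tensor([[[0., -inf]]])
    # i.e. True (keep) becomes 0. and False (skip) becomes -inf, with the leading
    # batch dimension added by the constructor.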
    @classmethod
    def from_multiplicative(cls: Type[Self], x: torch.Tensor) -> Self:
        """
        Create an AttentionMask given a multiplicative attention mask.
        """
        assert not x.dtype == torch.bool

        additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
        x = x.bool()

        additive_mask.masked_fill_(x, 0.0)
        additive_mask.masked_fill_(~x, float("-inf"))

        return cls(additive_mask)
    @classmethod
    def make_causal(
        cls: Type[Self],
        seq_len: int,
        to_seq_len: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Self:
        if not to_seq_len:
            to_seq_len = seq_len

        additive_mask = torch.triu(
            torch.ones(seq_len, to_seq_len, device=device, dtype=dtype)
            * float("-inf"),
            diagonal=1,
        )
        return cls(additive_mask=additive_mask, is_causal=True)
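    # Illustrative example (not part of the original source): make_causal(3)
    # stores the additive mask
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # so that position i can only attend to positions j <= i.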
    def make_crop(
        self, seq_len: int, to_seq_len: Optional[int] = None
    ) -> "AttentionMask":
        """
        Return a cropped attention mask, whose underlying tensor is a view of this one
        """
        if not to_seq_len:
            to_seq_len = seq_len

        return AttentionMask(
            self.values[:, :seq_len, :to_seq_len], is_causal=self.is_causal
        )
    def __repr__(self):
        return f"AttentionMask - causal {self.is_causal} - mask " + str(self.values)

    @property
    def device(self):
        return self.values.device

    @property
    def is_sparse(self):
        return False

    @property
    def ndim(self):
        return len(self.values.shape)

    @property
    def dtype(self):
        return self.values.dtype

    @property
    def shape(self):
        return self.values.shape

    def __add__(self, other):
        return AttentionMask(self.values + other.values, is_causal=False)
    def to(
        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
    ) -> "AttentionMask":
        assert device is None or isinstance(device, torch.device)
        assert dtype is None or isinstance(dtype, torch.dtype)
        assert device is not None or dtype is not None

        # Noop if we don't need to create another instance
        if ((device and device == self.device) or not device) and (
            (dtype and dtype == self.dtype) or not dtype
        ):
            return self

        return AttentionMask(self.values.to(device=device, dtype=dtype), self.is_causal)
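A minimal usage sketch follows (illustrative only, not part of the module above). It builds a causal mask with `make_causal` and adds `mask.values` to a tensor of raw attention scores before the softmax, which is the intended use of an additive mask. The `scores` tensor and its shape are assumptions made for the example.

    import torch

    from xformers.components.attention.attention_mask import AttentionMask

    # Causal additive mask; the constructor adds a leading batch dimension -> [1, 4, 4]
    mask = AttentionMask.make_causal(seq_len=4)

    # Hypothetical raw attention logits with shape [batch, to_sequence, from_sequence]
    scores = torch.randn(2, 4, 4)

    # Adding the mask drives the probability of masked positions to ~0 after the softmax
    weights = torch.softmax(scores + mask.values, dim=-1)

    # Boolean view: True where the attention value should be computed
    keep = mask.to_bool()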