# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import asdict, dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.nn.init import constant_
from xformers.components.attention import Attention
from xformers.components.input_projection import InputProjection, InputProjectionConfig
from xformers.components.positional_embedding import RotaryEmbedding
# Module-level logger, obtained under the "xformers" name.
logger = logging.getLogger("xformers")
@dataclass
class MultiHeadDispatchConfig:
    """
    Bundle of the constructor arguments of ``MultiHeadDispatch``.

    No field declares a default, so every one of them has to be passed when the
    config is built. ``MultiHeadDispatch.from_config`` drops the fields that are
    ``None`` so that the constructor defaults apply for them.
    """

    # Model / embedding dimension
    dim_model: int
    # Number of attention heads
    num_heads: int
    # Attention mechanism wrapped by the dispatcher
    attention: Attention
    # Single flag; the dispatcher applies a bare bool to all four projections
    bias: bool
    # Dropout probability applied after the output projection
    residual_dropout: float
    # Optional key dimension
    dim_key: Optional[int]
    # Optional value dimension
    dim_value: Optional[int]
    # Optional user-provided input projection module
    in_proj_container: Optional[InputProjection]
    # Whether Q, K and V use separate projection weights
    use_separate_proj_weight: Optional[bool]
    # Whether rotary embeddings are applied to the queries and keys
    use_rotary_embeddings: Optional[bool]
    # Optional user-provided output projection module
    out_proj: Optional[nn.Module]

    def __getitem__(self, item: str):
        """Allow dict-style access: ``config["dim_model"]`` returns ``config.dim_model``."""
        return getattr(self, item)
def _fold_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int):
    """
    Move the head axis forward and fold it into the batch axis.

    ``t`` is viewed as (B, S, H, Hs) and the result has shape (B * H, S, Hs),
    where row ``b * H + h`` holds head ``h`` of batch element ``b``.
    """
    per_head = t.view(B, S, H, Hs)
    heads_first = per_head.permute(0, 2, 1, 3)  # (B, H, S, Hs)
    return heads_first.reshape(B * H, S, Hs)
def _split_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int):
    """
    Expose the head axis and move it in front of the sequence axis.

    ``t`` is viewed as (B, S, H, Hs) and returned with shape (B, H, S, Hs).
    The batch axis is left untouched (nothing is folded into it).
    """
    per_head = t.view(B, S, H, Hs)
    return per_head.permute(0, 2, 1, 3)
class MultiHeadDispatch(nn.Module):
    """
    A multi-head masked self-attention dispatch mechanism, with a projection at the end,
    following the architecture proposed in `Attention is all you need`_, Vaswani et al.

    The actual attention mechanism can vary, as well as the projections.
    This can be used to wrap the proposed attention mechanisms and make them multi-head aware,
    but it is optional.

    Args:
        dim_model: The model/embedding dimension
        num_heads: The number of heads being used
        attention: The attention mechanism (needs to be registered to the xformers library)
        bias: Whether to use bias for the projections : (Q, K, V, Output).
            A single bool is also accepted and is applied to all four projections.
        residual_dropout: Amount of dropout on the residual path
        use_separate_proj_weight: Use different weights for the Q, K, V projections
        dim_key: Optionally use a different dimension for the key
        dim_value: Optionally use a different dimension for the value
        in_proj_container: Optionally provide the input projection module
        use_rotary_embeddings: Use rotary embeddings
        out_proj: Optionally provide the output projection module

    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
    """

    def __init__(
        self,
        dim_model: int,
        num_heads: int,
        attention: Attention,
        bias: Tuple[bool, bool, bool, bool] = (True, True, True, True),
        residual_dropout: float = 0.0,
        use_separate_proj_weight: bool = True,
        dim_key: Optional[int] = None,
        dim_value: Optional[int] = None,
        in_proj_container: Optional[InputProjection] = None,
        use_rotary_embeddings: Optional[bool] = False,
        out_proj: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ):
        super().__init__()

        if isinstance(bias, bool):
            logger.warning(
                "Single bias value provided for the MHA projections."
                + f" Assuming the same parameter ({bias}) is to be used everywhere"
            )
            bias = (bias, bias, bias, bias)

        # Validate the head count first, so that num_heads == 0 is reported by this
        # assertion instead of a ZeroDivisionError raised by the modulo below.
        assert num_heads > 0, "num_heads must be strictly positive"

        # static preset for now, each head works on 1/d the embeddings, could be relaxed
        assert (
            dim_model % num_heads == 0
        ), "dim_model must be divisible by num_heads"

        # Popular default is that all latent dimensions are the same
        dim_key = dim_key if dim_key else dim_model
        dim_value = dim_value if dim_value else dim_model

        self.num_heads = num_heads
        self.dim_key_head = dim_key // num_heads
        self.dim_value_head = dim_value // num_heads
        self.dim_model = dim_model
        self.attention = attention

        # key, query, value projections for all heads
        # critical options are
        # - are we sharing weights ?
        # - are we adding biases ?
        if attention.requires_input_projection:
            if in_proj_container is not None:
                self.in_proj_container = in_proj_container
            else:
                self.in_proj_container = InputProjection(
                    query_proj_params=InputProjectionConfig(
                        dim_model, dim_key, bias=bias[0]
                    ),
                    key_proj_params=InputProjectionConfig(
                        dim_model, dim_key, bias=bias[1]
                    ),
                    value_proj_params=InputProjectionConfig(
                        dim_model, dim_value, bias=bias[2]
                    ),
                    use_separate_proj_weight=use_separate_proj_weight,
                )

        # Optional rotary embeddings
        self.rotary_embeddings = (
            RotaryEmbedding(self.dim_key_head) if use_rotary_embeddings else None
        )

        # Regularization
        self.resid_drop = nn.Dropout(residual_dropout, inplace=False)

        # Output projection.
        # - Compare against None rather than relying on truthiness: a container
        #   module with no children (len() == 0) is falsy and would otherwise be
        #   silently replaced by the default projection.
        # - The re-assembled heads are dim_value wide, so this is the input width
        #   of the default projection (identical to dim_model by default).
        if out_proj is not None:
            self.proj = out_proj
        else:
            self.proj = nn.Linear(dim_value, dim_model, bias=bias[3])

        if isinstance(self.proj, nn.Linear) and self.proj.bias is not None:
            constant_(self.proj.bias, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        att_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Expected input dimensions are [batch size, sequence length, embed dim]
        Output dimensions are [batch size, sequence length, embed dim]

        Args:
            query: The query tensor
            key: The key tensor, ``query`` is used when omitted
            value: The value tensor, ``query`` is used when omitted
            att_mask: Optional mask, forwarded as-is to the attention mechanism
            key_padding_mask: Optional mask, forwarded as-is to the attention mechanism

        Raises:
            AssertionError: if a mask is passed to an attention which does not support it,
                if query and key have different sequence lengths while the attention
                requires them to match (or is causal), or if a projected embedding
                is not divisible by the number of heads.
        """

        if key is None:
            key = query
        if value is None:
            value = query

        # Inputs with mismatching batch sizes are expanded to the largest one
        if query.shape[0] != key.shape[0] or query.shape[0] != value.shape[0]:
            max_batch = max((query.shape[0], key.shape[0], value.shape[0]))
            query, key, value = (
                x.expand(max_batch, -1, -1) for x in (query, key, value)
            )

        B, S_Q, _ = query.size()  # Batch x Sequence x Embedding (latent)
        _, S_K, _ = key.size()  # K, Q's sequence length could differ

        # Catch different query and key length but a causal attention
        if S_Q != S_K:
            assert (
                not self.attention.requires_same_k_q_dimensions
            ), "This attention mechanism requires query and key to have the same sequence (context) lengths"

            if hasattr(self.attention, "causal"):
                assert not self.attention.causal, (
                    "Causal attention is not supported when key and query have different sequence lengths.\n"
                    + "In that case causality is ill-determined. Please pad your sequences accordingly"
                )

        # Only forward the masks which were actually provided
        kw_mask_args = {}
        if att_mask is not None:
            assert (
                self.attention.supports_attention_mask
            ), "This attention does not support attention masks"
            kw_mask_args["att_mask"] = att_mask

        if key_padding_mask is not None:
            assert (
                self.attention.supports_key_padding_mask
            ), "This attention does not support key padding masks"
            kw_mask_args["key_padding_mask"] = key_padding_mask

        # The attention handles everything itself: no projection, no head handling
        if self.attention.requires_skip_multi_head:
            return self.attention(query, key, value, **kw_mask_args)

        # Calculate query, key, values for all heads in batch
        if self.attention.requires_input_projection:
            q, k, v = self.in_proj_container(query=query, key=key, value=value)
        else:
            k, q, v = key, query, value

        # Check the dimensions properly
        for t, name in (
            (q, "projected query"),
            (v, "projected value"),
            (k, "projected key"),
        ):
            assert (
                t.shape[2] % self.num_heads == 0
            ), f"the {name} embeddings need to be divisible by the number of heads"

        # Optional: rotary embedding, add relative positioning information
        if self.rotary_embeddings is not None:
            # rotary requires the head dimension
            q = _split_heads(q, B, S_Q, self.num_heads, self.dim_key_head)
            k = _split_heads(k, B, S_K, self.num_heads, self.dim_key_head)
            v = _split_heads(v, B, S_K, self.num_heads, self.dim_value_head)

            q, k = self.rotary_embeddings(q=q, k=k)

            # Fold the heads back into the batch if the attention does not expect them
            if not self.attention.requires_head_dimension:
                q, k, v = q.flatten(0, 1), k.flatten(0, 1), v.flatten(0, 1)
        else:
            # Reshape k/q/v to either expose the heads, or fold the head dimension into the batch
            reshape_fn = (
                _split_heads if self.attention.requires_head_dimension else _fold_heads
            )

            q = reshape_fn(q, B, S_Q, self.num_heads, self.dim_key_head)
            k = reshape_fn(k, B, S_K, self.num_heads, self.dim_key_head)
            v = reshape_fn(v, B, S_K, self.num_heads, self.dim_value_head)

        # Self-attend
        y = self.attention(q, k, v, **kw_mask_args)

        # Re-assemble all head outputs side by side
        y = (
            y.view(B, self.num_heads, S_Q, self.dim_value_head)
            .transpose(1, 2)
            .flatten(start_dim=2, end_dim=3)
        )

        # Output projection, dropout and good to go
        y = self.resid_drop(self.proj(y))

        # Return the same sequence size as the input
        return y

    @classmethod
    def from_config(cls, config: MultiHeadDispatchConfig):
        """Build a ``MultiHeadDispatch`` from a ``MultiHeadDispatchConfig``."""
        # Generate the class inputs from the config
        # NOTE: per the dataclasses documentation, asdict() copies non-dataclass
        # values with copy.deepcopy, so the modules held by the config are copied.
        fields = asdict(config)

        # Skip all Nones so that default values are used
        fields = {k: v for k, v in fields.items() if v is not None}

        return cls(**fields)