Using BlockSparseAttention

BlockSparse attention uses Triton to restrict the attention computation to a set of tiles, which you define at construction time. A simple example is causal attention: only the lower-triangular tiles need to be computed. The tile size can be changed; the minimum is 16 coefficients along one dimension.

If you already have a per-coefficient pattern in mind and it does not map perfectly onto a block pattern, that is probably fine: BlockSparse is fast enough that dropping some of the computations after the fact with a fine-grained mask is likely still better than a dense computation. We provide a small helper (it is just max pooling) to convert a per-coefficient binary mask into the layout that you need to build a block-sparse attention.
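To make the max-pooling idea concrete, here is a minimal sketch in plain PyTorch. The function name `mask_to_block_layout` is hypothetical and stands in for the library helper mentioned above; it is not the xFormers API, and it assumes a 2D binary mask whose sides are multiples of the block size.

```python
import torch
import torch.nn.functional as F


def mask_to_block_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor:
    """Hypothetical helper (not the xFormers one): a tile is kept in the
    layout as soon as at least one coefficient inside it is non-zero."""
    rows, cols = mask.shape
    if rows % block_size or cols % block_size:
        raise ValueError("mask dimensions must be multiples of block_size")

    # max_pool2d expects a (batch, channels, height, width) float tensor
    pooled = F.max_pool2d(
        mask.float()[None, None], kernel_size=block_size, stride=block_size
    )
    return pooled[0, 0].to(torch.long)


# Toy causal mask: 8 x 8 coefficients, tiles of 4 x 4
toy_mask = torch.tril(torch.ones(8, 8))
toy_layout = mask_to_block_layout(toy_mask, block_size=4)
print(toy_layout)
```

On this toy mask the upper-right tile is the only one that is entirely empty, so it is the only one dropped from the layout. The diagonal tiles are kept in full even though half of their coefficients are masked, which is why a per-coefficient mask is still useful at forward time.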

Please note that for now blocksparse attention requires the sequence length to be a power of two.
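A quick way to check this constraint before building the attention is a bit trick on the sequence length. The sketch below is plain Python; `is_power_of_two` and `next_power_of_two` are illustrative names, not part of xFormers, and padding up to the next power of two is only one possible way to satisfy the requirement.

```python
def is_power_of_two(n: int) -> bool:
    # A positive power of two has a single bit set, so n & (n - 1) is zero
    return n > 0 and (n & (n - 1)) == 0


def next_power_of_two(n: int) -> int:
    # Smallest power of two that is greater than or equal to n (n >= 1)
    if n < 1:
        raise ValueError("n must be >= 1")
    return 1 << (n - 1).bit_length()


seq_len = 2048
assert is_power_of_two(seq_len)

# A sequence of 1500 tokens would not qualify as-is;
# this gives the length it would have to be padded to
padded_len = next_power_of_two(1500)
```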

Let’s look at an example:

import torch

from xformers.components import MultiHeadDispatch
from xformers.components.attention import BlockSparseAttention

BATCH = 2
HEADS = 8
SEQ = 2048
EMB = 1024
BLOCK_SIZE = 32
DROPOUT = 0.1
dtype = torch.float16

# Let's try out a causal mask, but really it could be anything "block sparse enough"
causal_mask = torch.tril(torch.ones((SEQ, SEQ), device=torch.device("cuda"), dtype=dtype))

blocks = SEQ // BLOCK_SIZE
causal_layout = torch.tril(torch.ones([HEADS, blocks, blocks]))

# Let's build our blocksparse attention. Please note that the layout can be
# [SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE] or [HEADS, SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE]
# so that _you can pass a different layout per head_
attention = BlockSparseAttention(layout=causal_layout, block_size=BLOCK_SIZE, dropout=DROPOUT, num_heads=HEADS)

# For convenience, let's build our multihead attention now
# "multi_head" will be responsible for the forward
multi_head = (
    MultiHeadDispatch(
        seq_len=SEQ,
        dim_model=EMB,
        residual_dropout=DROPOUT,
        num_heads=HEADS,
        attention=attention,
    )
    .cuda()
    .half()
)

# Now run a forward pass on some random data
# Note that passing a per-coefficient mask makes it possible to remove extra coefficients,
# which were required by the blockification
query = torch.randn((BATCH, SEQ, EMB), requires_grad=True, device=torch.device("cuda"), dtype=dtype)

# Self attention in this particular example, no limitations really
att_val = multi_head(query=query, key=query, value=query, att_mask=causal_mask)


#########################################
# Bonus: compare the memory use vs dense:
def mem_use(fn, kwargs, title):
    # bookkeeping
    import time

    start = time.time()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # actually run the function
    fn(**kwargs)
    torch.cuda.synchronize()
    stop = time.time()

    # now report
    max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
    print(f"{title} - Peak memory use: {max_memory}MB - {round((stop-start)*1e6)/1e3}ms")


pytorch_multihead = torch.nn.MultiheadAttention(
    EMB, HEADS, batch_first=True, device=torch.device("cuda"), dtype=torch.float16
)

mem_use(multi_head, {"query": query, "key": query, "value": query, "att_mask": causal_mask}, "Blocksparse")
mem_use(pytorch_multihead, {"query": query, "key": query, "value": query, "attn_mask": causal_mask}, "PyTorch")

On a V100, with PyTorch 1.9, Triton 1.1 and xFormers 0.0.2 this reports something along the lines of:

Blocksparse - Peak memory use: 151MB - 6.619ms
PyTorch - Peak memory use: 393MB - 6.837ms

Note that the pattern here is not that sparse (only half of the matrix is empty). The sparser the pattern, the more the comparison will favor BlockSparseAttention.
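For a sense of scale, the fraction of tiles actually computed for the causal layout can be worked out from the constants used in the example. This is only back-of-the-envelope arithmetic on the tile count, not a measurement, and the variable names are illustrative.

```python
SEQ = 2048
BLOCK_SIZE = 32

blocks = SEQ // BLOCK_SIZE                  # 64 tiles along each dimension
total_tiles = blocks * blocks               # 4096 tiles for a dense attention
causal_tiles = blocks * (blocks + 1) // 2   # lower triangle, diagonal included

density = causal_tiles / total_tiles
print(f"{causal_tiles}/{total_tiles} tiles computed ({density:.1%})")
```

The causal layout keeps 2080 of the 4096 tiles, slightly more than half because the diagonal tiles are computed in full. A layout that keeps fewer tiles lowers this ratio accordingly.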