Using the Reversible block
Intro
This block applies to residual paths, and was first proposed by Gomez et al. (1). Its application in the Transformer (3) context was first proposed in the Reformer (2) paper, and is largely unrelated to the other proposals from that paper (LSH and chunked MLP processing).
We use a lightly adapted version of the implementation by Robin Bruegger, along with some blocks from LucidRains.
A reversible layer requires two inputs (x1, x2) and produces two outputs (y1, y2) via two functions F and G, following the relations:
y1 = x1 + F(x2)
y2 = x2 + G(y1)
In turn, this means that (x1, x2) can be recovered from (y1, y2) (see (1) for details):
x2 = y2 - G(y1) # Note that another FW-like pass is needed
x1 = y1 - F(x2)
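The relations above can be checked numerically. The following is a minimal sketch in plain Python, not the xFormers implementation: F and G are arbitrary toy functions chosen for illustration, and the helper names (`add`, `sub`, `reversible_forward`, `reversible_inverse`) are hypothetical. Note that neither F nor G has to be invertible for the block as a whole to be reversible.

```python
import math


def F(xs):
    # Arbitrary toy function standing in for the first sub-layer.
    return [math.tanh(v) for v in xs]


def G(xs):
    # Toy stand-in for the second sub-layer; not invertible on its own
    # (v and -v give the same output).
    return [0.5 * v * v for v in xs]


def add(a, b):
    return [u + v for u, v in zip(a, b)]


def sub(a, b):
    return [u - v for u, v in zip(a, b)]


def reversible_forward(x1, x2):
    y1 = add(x1, F(x2))
    y2 = add(x2, G(y1))
    return y1, y2


def reversible_inverse(y1, y2):
    x2 = sub(y2, G(y1))  # re-evaluates G: the extra forward-like pass
    x1 = sub(y1, F(x2))  # re-evaluates F on the recovered x2
    return x1, x2


x1, x2 = [0.1, -0.4, 2.0], [1.5, 0.3, -0.7]
y1, y2 = reversible_forward(x1, x2)
r1, r2 = reversible_inverse(y1, y2)

# Largest absolute reconstruction error; only floating-point rounding remains.
max_err = max(abs(a - b) for a, b in zip(x1 + x2, r1 + r2))
print(max_err)
```

The inputs are recovered up to floating-point rounding, at the cost of evaluating F and G a second time.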
The effect is comparable to activation checkpointing, in that it trades extra compute for lower GPU memory use. One benefit is that no extra wrapping is needed: all the residual paths are naturally checkpointed. In a distributed setting, freeing up GPU memory can make it possible to use fewer GPUs, and the saved communication cost can more than make up for the extra compute.
Moreover, if your model is made of a stack of reversible blocks, then the memory required for activations does not increase with the number of blocks, since the inputs of each block can be recomputed from its outputs.
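To illustrate why a deeper stack does not need more stored activations, here is a small sketch in plain Python. It is an illustration under stated assumptions rather than the library's code: `make_fn`, `build_stack`, `forward` and `recover_inputs` are hypothetical names, the sub-layers are toy functions, and the gradient computation that a real backward pass would perform at each step is omitted.

```python
import math


def make_fn(shift):
    # Toy elementwise function standing in for a learned sub-layer.
    return lambda xs: [0.1 * math.tanh(v + shift) for v in xs]


def build_stack(depth):
    # Hypothetical stack: one (F, G) pair per reversible block.
    return [(make_fn(0.3 * i), make_fn(-0.2 * i)) for i in range(depth)]


def forward(blocks, x1, x2):
    # Each iteration overwrites the running pair; no per-block activation is kept.
    for f, g in blocks:
        x1 = [a + b for a, b in zip(x1, f(x2))]  # y1 = x1 + F(x2)
        x2 = [a + b for a, b in zip(x2, g(x1))]  # y2 = x2 + G(y1)
    return x1, x2


def recover_inputs(blocks, y1, y2):
    # Walk the stack backwards, rebuilding each block's inputs from its outputs.
    for f, g in reversed(blocks):
        y2 = [a - b for a, b in zip(y2, g(y1))]  # x2 = y2 - G(y1)
        y1 = [a - b for a, b in zip(y1, f(y2))]  # x1 = y1 - F(x2)
    return y1, y2


x1, x2 = [0.1, -0.4, 2.0], [1.5, 0.3, -0.7]
errors = {}
for depth in (2, 8, 32):
    blocks = build_stack(depth)
    out1, out2 = forward(blocks, x1, x2)  # only the final pair is retained
    r1, r2 = recover_inputs(blocks, out1, out2)
    errors[depth] = max(abs(a - b) for a, b in zip(x1 + x2, r1 + r2))

print(errors)
```

Whatever the depth, only the final output pair is needed to get back to the original inputs, so the working set stays at two tensors. The parameters of the blocks still scale with depth, of course.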
Transformer
Considering the multi-head attention (MHA) and feedforward blocks (including the residual paths), one can set F as MHA (+ layer norm) and G as Feedforward (+ layer norm) and get something very close to, but not exactly the same as, the original Transformer formulation from Vaswani et al. (3), as follows:
y1 = x1 + MHA(x2)
y2 = x2 + Feedforward(y1)
One difference is that the residual path in the Feedforward deals with the original input, and not the MHA output. In practice, if dim(x1) == dim(x2) == dim(model), the accuracy should not be affected, as verified in (2) and in xFormers.
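The difference can be made explicit by writing the two formulations side by side. This is a sketch under simplifying assumptions: layer norms are left out, h is a name introduced here for the intermediate result of the standard layer, and the last line assumes that both reversible streams start from the same input x (x1 = x2 = x).

```latex
\begin{aligned}
\text{standard:}\quad   h   &= x + \mathrm{MHA}(x),     \qquad y   = h + \mathrm{FF}(h) \\
\text{reversible:}\quad y_1 &= x_1 + \mathrm{MHA}(x_2), \qquad y_2 = x_2 + \mathrm{FF}(y_1) \\
\text{with } x_1 = x_2 = x:\quad y_1 &= h, \qquad y_2 = x + \mathrm{FF}(h) = y - \mathrm{MHA}(x)
\end{aligned}
```

Under these assumptions the feedforward sees the same input in both cases, and only the residual term it is added to changes: x in the reversible form, h in the standard one.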
In practice
This repository exposes two main helpers in xformers.components.reversible: ReversibleBlock and ReversibleSequence. ReversibleBlock will take f and g as defined above, and ReversibleSequence can combine them sequentially, similarly to torch.nn.ModuleList.
class ReversibleBlock(nn.Module):
    def __init__(self, f: nn.Module, g: nn.Module):
        ...

    def forward(self, x: torch.Tensor, f_args={}, g_args={}):
        ...


class ReversibleSequence(nn.Module):
    def __init__(self, blocks: nn.ModuleList):
        ...

    def forward(self, x, arg_route=(True, False), **kwargs):
        """
        arg_route: whether to route the kwargs to f and g
        """
        ...
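One plausible reading of `arg_route` is that each of its two flags decides whether the extra keyword arguments are forwarded to f and to g respectively, so the default `(True, False)` would pass them to f only. The sketch below illustrates that reading in plain Python; `route_kwargs` is a hypothetical helper written for this page, `att_mask` is a made-up keyword, and neither is taken from the library.

```python
def route_kwargs(arg_route=(True, False), **kwargs):
    """Hypothetical helper: split kwargs into (f_args, g_args) following arg_route."""
    route_f, route_g = arg_route
    f_args = dict(kwargs) if route_f else {}  # forwarded to f when the first flag is set
    g_args = dict(kwargs) if route_g else {}  # forwarded to g when the second flag is set
    return f_args, g_args


# Default routing: kwargs reach f (e.g. attention) but not g (e.g. feedforward).
f_args, g_args = route_kwargs(att_mask="mask")
print(f_args, g_args)

# Routing to both functions.
both_f, both_g = route_kwargs(arg_route=(True, True), att_mask="mask")
print(both_f, both_g)
```

With F as attention and G as feedforward, this default would let an attention mask reach the attention layer without being passed to the feedforward.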
Reversible layers are also exposed as a boolean option when building complete xFormers (which is optional), as defined in xformers.factory.model_factory. Please note that the reversible layer is not yet compatible with the use of multiple forward passes and DDP.
class xFormerStackConfig:
    block_config: Union[xFormerEncoderConfig, xFormerDecoderConfig]
    num_layers: int
    reversible: bool  # the sequence of layers becomes reversible