# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

r"""
This module provides an implementation of the *excitation backpropagation*
method of [EBP]_  for saliency visualization. It is a backpropagation method,
and thus works by changing the definition of the backward functions of some
layers.

In simple cases, the :func:`excitation_backprop` function can be used to obtain
the required visualization, as in the following example:

.. literalinclude:: ../examples/excitation_backprop.py
   :language: python
   :linenos:


Alternatively, you can explicitly use the :class:`ExcitationBackpropContext`,
as follows:

.. literalinclude:: ../examples/excitation_backprop_manual.py
   :language: python
   :linenos:

See also :ref:`backprop` for further examples and discussion.

Contrastive variant
~~~~~~~~~~~~~~~~~~~

The contrastive variant of excitation backprop passes the data twice through
the network. The first pass is used to obtain "contrast" activations at some
intermediate layer ``contrast_layer``; these are obtained by flipping the sign
of the weights of the last classification layer ``classifier_layer``. The
visualization is then obtained at some earlier ``input_layer``. The function
:func:`contrastive_excitation_backprop` can be used to compute this saliency:

.. literalinclude:: ../examples/contrastive_excitation_backprop.py
   :language: python
   :linenos:

This can also be done "manually", as follows:

.. literalinclude:: ../examples/contrastive_excitation_backprop_manual.py
   :language: python
   :linenos:

Theory
~~~~~~

Excitation backprop modifies the backward version of all linear layers in the
network. For a simple 1D case, let the forward layer be given by:

.. math::
     y_i = \sum_{j=1}^N w_{ij} x_j

where :math:`x \in \mathbb{R}^N` is the input and
:math:`w \in \mathbb{R}^{M \times N}` the weight matrix. On the way back, if
:math:`p \in \mathbb{R}^{M}` is the output pseudo-gradient, the input
pseudo-gradient :math:`p'` is given by:

.. math::
    p'_j =
    \sum_{i=1}^M
        \frac
        {w^+_{ij} x_j}
        {\sum_{k=1}^N w^+_{ik} x_k} p_i
    \quad\text{where}\quad
    w^+_{ij} = \max\{0, w_{ij}\}
    :label: linear-back

Note that [EBP]_ assumes that the input activations :math:`x` are always
non-negative. This is often true in practice, as linear layers are usually
preceded by non-negative activation functions such as ReLU, so there is
nothing to be done.

Note also that we can rearrange :eq:`linear-back` as

.. math::
    p'_j =
    x_j
    \sum_{i=1}^M
        w^+_{ij}
        \hat p_i,
    \quad\text{where}\quad
    \hat p_i =
        \frac
        {p_i}
        {\sum_{k=1}^N w^+_{ik} x_k}.


Here :math:`\hat p` is a normalized version of the output pseudo-gradient, and
the summation is exactly the operation performed by the standard backward
function of the linear layer when given :math:`\hat p` as input.

All linear layers, including convolution and deconvolution layers as well as
average pooling layers, can be processed in this manner. In general, let
:math:`y = f(x,w)` be a linear layer. In order to compute :eq:`linear-back`, we
can expand it as:

.. math::
    p' = x \odot f^*(x, w^+, p \oslash f(x, w^+))

where :math:`f^*(x, w, p)` is the standard backward function (vector-Jacobian
product) for the linear layer :math:`f` and :math:`\odot` and :math:`\oslash`
denote element-wise multiplication and division, respectively.
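
As an illustration, the following minimal sketch (not part of the library's
API; the tensors ``x``, ``w`` and ``p`` are made up for the example) evaluates
this rule for a fully-connected layer using standard autograd primitives:

.. code-block:: python

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 5)           # non-negative input activations
    w = torch.randn(3, 5)          # layer weights
    p = torch.rand(1, 3)           # output pseudo-gradient

    x.requires_grad_(True)
    w_plus = w.clamp(min=0)        # w^+
    y_plus = F.linear(x, w_plus)   # f(x, w^+)
    p_hat = p / (y_plus + 1e-20)   # p / f(x, w^+), element-wise
    # Vector-Jacobian product f^*(x, w^+, p_hat), then re-weight by x.
    p_prime = x * torch.autograd.grad(y_plus, x, p_hat)[0]

Up to implementation details, this is the sequence of steps carried out by the
patched backward functions defined below.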

The **contrastive variant** of excitation backprop is similar, but uses the
idea of contrasting the excitations for one class with those of all the other
classes.

In order to obtain the "contrast" signal, excitation backprop is run as before
except for the last *classification* linear layer: to backpropagate the
excitation of this layer only, its weight parameter :math:`w` is replaced
with its opposite :math:`-w`. The excitations are then backpropagated down to
an intermediate *contrast* linear layer as normal.

Once the contrast activations have been obtained, excitation backprop is run
again, this time with the "default" weights even for the last linear layer.
However, during backpropagation, when the contrast layer is reached again, the
contrast is subtracted from the excitations. Then, the excitations are
propagated backward as usual.
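
As a toy illustration of the contrast construction (all sizes and names below
are made up, and the sketch covers only the special case in which the contrast
is taken directly at the input of the classifier):

.. code-block:: python

    import torch

    x = torch.rand(1, 512)    # non-negative activations entering the classifier
    w = torch.randn(10, 512)  # classifier weights (10 classes, made-up sizes)
    target = 3                # class of interest

    def ebp_step(weight):
        # One excitation backprop step through the classifier for `target`,
        # following :eq:`linear-back` with a one-hot output pseudo-gradient.
        w_plus = weight.clamp(min=0)
        z = x * w_plus[target]        # excitations w^+_{cj} x_j
        return z / (z.sum() + 1e-20)

    # "Contrast" signal: excitations obtained with the weights negated.
    contrast = ebp_step(-w)

    # Contrastive excitations just below the classifier.
    contrastive = ebp_step(w) - contrast

In the library, :func:`contrastive_excitation_backprop` automates this for an
arbitrary contrast layer by running two forward-backward passes through the
whole network.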

References:

    .. [EBP] Jianming Zhang, Zhe Lin, Jonathan Brandt, Xiaohui Shen, Stan
             Sclaroff, *Top-down Neural Attention by Excitation Backprop*,
             ECCV 2016, `<https://arxiv.org/abs/1608.00507>`__.
"""

__all__ = [
    "contrastive_excitation_backprop",
    "eltwise_sum",
    "excitation_backprop",
    "ExcitationBackpropContext",
]

import torch
import torch.nn as nn

from torch.autograd import Function
from torchvision.models.resnet import ResNet, Bottleneck, BasicBlock

from .common import Patch, Probe
from .common import attach_debug_probes, get_backward_gradient, get_module
from .common import saliency, resize_saliency


class EltwiseSumFunction(Function):
    """
    Implementation of a skip connection (i.e., element-wise sum function)
    as a :class:`torch.autograd.Function`. This is necessary for patching
    the skip connection as a :class:`torch.nn.Module` object (i.e.,
    :class:`EltwiseSum`).
    """
    @staticmethod
    def forward(ctx, *inputs):
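        # The inputs are saved only so that backward can recover how many
        # addends there were.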
        ctx.save_for_backward(*inputs)
        output = inputs[0]
        for i in range(1, len(inputs)):
            output = output + inputs[i]
        return output

    @staticmethod
    def backward(ctx, grad_output):
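        # The derivative of a sum is 1 with respect to each addend, so the
        # incoming gradient is passed through unchanged to every input.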
        inputs = ctx.saved_tensors
        return (grad_output,) * len(inputs)


# Alias for :class:`EltwiseSumFunction`.
eltwise_sum = EltwiseSumFunction.apply


class EltwiseSum(nn.Module):
    """
    Implementation of a skip connection (i.e., element-wise sum) as a
    :class:`torch.nn.Module` (this is necessary to be able to patch the
    skip connection).
    """

    def forward(self, *inputs):
        return eltwise_sum(*inputs)


def update_resnet(model, debug=False):
    """
    Update a ResNet model to use :class:`EltwiseSum` for the skip
    connection.

    Args:
        model (:class:`torchvision.models.resnet.ResNet`): ResNet model.
        debug (bool): If True, print debug statements.

    Returns:
        model (:class:`torchvision.models.resnet.ResNet`): ResNet model
        that uses :class:`EltwiseSum` for the skip connections. The forward
        functions of :class:`torchvision.models.resnet.BasicBlock` and
        :class:`torchvision.models.resnet.Bottleneck` are modified.
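
    Example (a minimal sketch; ``resnet50`` here is the stock torchvision
    model, used only for illustration)::

        from torchvision.models import resnet50

        model = resnet50().eval()
        model = update_resnet(model, debug=True)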
    """
    assert isinstance(model, ResNet)

    def bottleneck_forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.skip(out, identity)
        out = self.relu(out)

        return out

    def basicblock_forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.skip(out, identity)
        out = self.relu(out)

        return out

    for module_name, module in model.named_modules():
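        # Swap the in-place sum used by torchvision's residual blocks for the
        # patchable EltwiseSum module and rebind the block's forward method.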
        if isinstance(module, Bottleneck):
            module.skip = EltwiseSum()
            module.forward = bottleneck_forward.__get__(module)
        elif isinstance(module, BasicBlock):
            module.skip = EltwiseSum()
            module.forward = basicblock_forward.__get__(module)
        else:
            continue
        if debug:
            print('Adding EltwiseSum as skip connection in {}.'.format(
                module_name))
    return model


class ExcitationBackpropContext(object):
    """Context to use Excitation Backpropagation rules."""

    @staticmethod
    def _patch_conv(target_name, enable, debug):
        """Patch conv functions to use excitation backprop rules.

        Replicates the implementation provided in:
        https://github.com/jimmie33/Caffe-ExcitationBP/blob/master/src/caffe/layers/conv_layer.cpp

        Args:
            target_name (str): name of the function to patch (e.g.,
                ``'torch.nn.functional.conv2d'``).
            enable (bool): If True, enable excitation backprop rules.
            debug (bool): If True, print debug statements.

        Returns:
            :class:`.common.Patch`: object that patches the function with the
            name :attr:`target_name` with a new callable that implements the
            excitation backprop rules for the conv function.
        """
        target, attribute = Patch.resolve(target_name)
        conv = getattr(target, attribute)

        def forward(ctx, input, weight, bias=None, *args, **kwargs):
            ctx.save_for_backward(input, weight, bias)
            ctx.args = args
            ctx.kwargs = kwargs
            if debug:
                print("EBP " + target_name)
            return conv(input, weight, bias, *ctx.args, **ctx.kwargs)

        def backward(ctx, grad_output):
            def get(i):
                x = ctx.saved_tensors[i]
                if x is None:
                    return
                x = x.detach()
                x.requires_grad_(ctx.needs_input_grad[i])
                return x

            inputs_ = get(0), get(1), get(2)
            grad_inputs_ = [None, None, None]
            subset = [i for i, g in enumerate(ctx.needs_input_grad) if g]
            inputs_subset = [inputs_[i] for i in subset]

            with torch.enable_grad():
                # EBP changes only the gradients w.r.t. inputs_[0]. We also
                # compute the gradients w.r.t. the parameters if needed,
                # although these are not meaningful here; it might be better
                # to set them to None.
                #
                # The expectation is that the input to the conv layer is
                # non-negative, which is typical for all but the first layer
                # due to the ReLUs. Other implementations enforce this by
                # clamping inputs_[0].

                # 1. Set the weights W+ to be non-negative and disable the
                #    bias.
                if enable:
                    input = inputs_[0]
                    weight = inputs_[1].clamp(min=0)
                    bias = None
                else:
                    input = inputs_[0]
                    weight = inputs_[1]
                    bias = inputs_[2]

                # 2. Do the forward pass.
                output_ebp = conv(
                    input, weight, bias, *ctx.args, **ctx.kwargs
                )

                # 3. Normalize the gradient by the output of the forward pass.
                if enable:
                    grad_output = grad_output / (output_ebp + 1e-20)

                # 4. Do the backward pass.
                grads = torch.autograd.grad(
                    output_ebp, inputs_subset, grad_output, only_inputs=True)
                for i, j in enumerate(subset):
                    grad_inputs_[j] = grads[i]

                # 5. Multiply the gradient by the layer's input.
                if ctx.needs_input_grad[0] and enable:
                    grad_inputs_[0] *= inputs_[0]

            return (grad_inputs_[0], grad_inputs_[1], grad_inputs_[2],
                    None, None, None, None)

        autograd_conv = type('EBP_' + attribute, (torch.autograd.Function,), {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
        })

        return Patch(target_name, autograd_conv.apply)

    @staticmethod
    def _patch_pool(target_name, enable, debug):
        """Patch pool functions to use excitation backprop rules.

        Replicates the implementation provided in:
        https://github.com/jimmie33/Caffe-ExcitationBP/blob/master/src/caffe/layers/pooling_layer.cpp

        Args:
            target_name (str): name of the function to patch (e.g.,
                ``'torch.nn.functional.avg_pool2d'``).
            enable (bool): If True, enable excitation backprop rules.
            debug (bool): If True, print debug statements.

        Returns:
            :class:`.common.Patch`: object that patches the function with the
            name :attr:`target_name` with a new callable that implements the
            excitation backprop rules for the pool function.
        """
        target, attribute = Patch.resolve(target_name)
        pool = getattr(target, attribute)

        def forward(ctx, input, *args, **kwargs):
            ctx.save_for_backward(input)
            ctx.args = args
            ctx.kwargs = kwargs
            if debug:
                print('EBP ' + target_name)
            return pool(input, *ctx.args, **ctx.kwargs)

        def backward(ctx, grad_output):
            if not ctx.needs_input_grad[0]:
                return (None,)
            input_ = ctx.saved_tensors[0].detach()
            input_.requires_grad_(True)

            with torch.enable_grad():
                # 1. Do the forward pass.
                output_ebp = pool(input_, *ctx.args, **ctx.kwargs)

                # 2. Normalize the gradient by the output of the forward pass.
                if enable:
                    grad_output = grad_output / (output_ebp + 1e-20)

                # 3. Do the backward pass.
                grad_input_ = torch.autograd.grad(
                    output_ebp, input_, grad_output, only_inputs=True)[0]

                # 4. Multiply the gradient by the layer's input.
                if enable:
                    grad_input_ *= input_

            return grad_input_, None, None, None, None, None, None

        autograd_pool = type('EBP_' + attribute, (torch.autograd.Function,), {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
        })

        return Patch(target_name, autograd_pool.apply)

    @staticmethod
    def _patch_norm(target_name, enable, debug):
        """Patch normalization functions (e.g., batch norm) to use excitation
        backprop rules.

        Args:
            target_name (str): name of the function to patch (e.g.,
                ``'torch.nn.functional.batch_norm'``).
            enable (bool): If True, enable excitation backprop rules.
            debug (bool): If True, print debug statements.

        Returns:
            :class:`.common.Patch`: object that patches the function with the
            name :attr:`target_name` with a new callable that implements the
            excitation backprop rules for the normalization function.
        """
        target, attribute = Patch.resolve(target_name)
        norm = getattr(target, attribute)

        def forward(ctx, input, *args, **kwargs):
            ctx.save_for_backward(input)
            ctx.args = args
            ctx.kwargs = kwargs
            if debug:
                print('EBP ' + target_name)
            return norm(input, *ctx.args, **ctx.kwargs)

        def backward(ctx, grad_output):
            if enable:
                return grad_output, None, None, None, None, None, None, None

            input_ = ctx.saved_tensors[0].detach()
            input_.requires_grad_(True)
            with torch.enable_grad():
                output_ebp = norm(input_, *ctx.args, **ctx.kwargs)
                grad_input_ = torch.autograd.grad(
                    output_ebp, input_, grad_output, only_inputs=True)[0]
            return grad_input_, None, None, None, None, None, None, None

        autograd_norm = type('EBP_' + attribute, (torch.autograd.Function,), {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
        })

        return Patch(target_name, autograd_norm.apply)

    def _patch_eltwise_sum(self, target_name, enable, debug):
        """Patch the element-wise sum function (e.g., skip connection) to use
        excitation backprop rules.

        Args:
            target_name (str): name of the function to patch (e.g.,
                ``'torchray.attribution.excitation_backprop.eltwise_sum'``).
            enable (bool): If True, enable excitation backprop rules.
            debug (bool): If True, print debug statements.

        Returns:
            :class:`.common.Patch`: object that patches the function with the
            name :attr:`target_name` with a new callable that implements the
            excitation backprop rules for the element-wise sum function.
        """
        target, attribute = Patch.resolve(target_name)
        eltwise_sum_f = getattr(target, attribute)

        def forward(ctx, *inputs):
            ctx.save_for_backward(*inputs)
            if debug:
                print("EBP " + target_name)
            return eltwise_sum_f(*inputs)

        def backward(ctx, grad_output):
            inputs = ctx.saved_tensors
            if not enable:
                return (grad_output,) * len(inputs)
            inputs = [inputs[i].detach() for i in range(len(inputs))]
            output = eltwise_sum_f(*inputs)
            grad_outputs = []
            for input in inputs:
                grad_outputs.append(input / output * grad_output)
            return tuple(grad_outputs)

        autograd_eltwise_sum = type(
            'EBP_' + attribute, (torch.autograd.Function,), {
                'forward': staticmethod(forward),
                'backward': staticmethod(backward),
            })

        return Patch(target_name, autograd_eltwise_sum.apply)

    def __init__(self, enable=True, debug=False):
        self.enable = enable
        self.debug = debug
        self.patches = []

    def __enter__(self):
        # Patch torch functions for convolutional, linear, average pooling,
        # adaptive average pooling, and batch normalization layers. Also patch
        # the eltwise_sum function (for the skip connections in ResNet
        # models).
        self.patches = [
            self._patch_conv('torch.nn.functional.conv1d',
                             self.enable, self.debug),
            self._patch_conv('torch.nn.functional.conv2d',
                             self.enable, self.debug),
            self._patch_conv('torch.nn.functional.conv3d',
                             self.enable, self.debug),
            self._patch_conv('torch.nn.functional.linear',
                             self.enable, self.debug),
            self._patch_pool('torch.nn.functional.avg_pool1d',
                             self.enable, self.debug),
            self._patch_pool('torch.nn.functional.avg_pool2d',
                             self.enable, self.debug),
            self._patch_pool('torch.nn.functional.avg_pool3d',
                             self.enable, self.debug),
            self._patch_pool('torch.nn.functional.adaptive_avg_pool1d',
                             self.enable, self.debug),
            self._patch_pool('torch.nn.functional.adaptive_avg_pool2d',
                             self.enable, self.debug),
            self._patch_pool('torch.nn.functional.adaptive_avg_pool3d',
                             self.enable, self.debug),
            self._patch_norm('torch.nn.functional.batch_norm',
                             self.enable, self.debug),
            self._patch_eltwise_sum(
                'torchray.attribution.excitation_backprop.eltwise_sum',
                self.enable, self.debug),
        ]
        return self

    def __exit__(self, type, value, traceback):
        for patch in self.patches:
            patch.remove()
        return False  # re-raise any exception


def _get_classifier_layer(model):
    r"""Get the classifier layer.

    Args:
        model (:class:`torch.nn.Module`): a model.

    Returns:
        (:class:`torch.nn.Module`, str): tuple of the last layer and its name.
    """
    # Get the last layer with weight parameters.
    last_layer_name = None
    last_layer = None
    for parameter_name, _ in list(model.named_parameters())[::-1]:
        if '.weight' in parameter_name:
            last_layer_name, _ = parameter_name.split('.weight')
            last_layer = get_module(model, last_layer_name)
            # Check that the last layer is convolutional or linear.
            if (isinstance(last_layer, nn.Conv1d)
                    or isinstance(last_layer, nn.Conv2d)
                    or isinstance(last_layer, nn.Conv3d)
                    or isinstance(last_layer, nn.Linear)):
                break
            else:
                last_layer_name = None
                last_layer = None
    assert last_layer_name is not None
    assert last_layer is not None
    return last_layer, last_layer_name


def gradient_to_excitation_backprop_saliency(x):
    r"""Convert a gradient to an excitation backprop saliency map.

    The tensor :attr:`x` must have a valid gradient ``x.grad``. The function
    then computes the excitation backprop saliency map :math:`s` given by:

    .. math::
        s_{n,1,u} = \sum_{0 \leq c < C} dx_{ncu}

    where :math:`n` is the instance index, :math:`c` the channel index and
    :math:`u` the spatial multi-index (usually of dimension 2 for images).

    Args:
        x (:class:`torch.Tensor`): activation with gradient.

    Returns:
        :class:`torch.Tensor`: saliency map.
    """
    return torch.sum(x.grad, 1, keepdim=True)


def gradient_to_contrastive_excitation_backprop_saliency(x):
    r"""Convert a gradient to a contrastive excitation backprop saliency map.

    The tensor :attr:`x` must have a valid gradient ``x.grad``. The function
    then computes the contrastive excitation backprop saliency map :math:`s`
    given by:

    .. math::
        s_{n,1,u} = \max\left(\sum_{0 \leq c < C} dx_{ncu}, 0\right)

    where :math:`n` is the instance index, :math:`c` the channel index and
    :math:`u` the spatial multi-index (usually of dimension 2 for images).

    Args:
        x (:class:`torch.Tensor`): activation with gradient.

    Returns:
        :class:`torch.Tensor`: saliency map.
    """
    return torch.clamp(torch.sum(x.grad, 1, keepdim=True), min=0)


def excitation_backprop(*args,
                        context_builder=ExcitationBackpropContext,
                        gradient_to_saliency=gradient_to_excitation_backprop_saliency,
                        **kwargs):
    r"""Excitation backprop.

    The function takes the same arguments as :func:`.common.saliency`, with
    the defaults required to apply the excitation backprop method, and
    supports the same return values.
    """
    assert context_builder is ExcitationBackpropContext
    return saliency(
        *args,
        context_builder=context_builder,
        gradient_to_saliency=gradient_to_saliency,
        **kwargs
    )


def contrastive_excitation_backprop(model,
                                    input,
                                    target,
                                    saliency_layer,
                                    contrast_layer,
                                    classifier_layer=None,
                                    resize=False,
                                    resize_mode='bilinear',
                                    get_backward_gradient=get_backward_gradient,
                                    debug=False):
    """Contrastive excitation backprop.

    Args:
        model (:class:`torch.nn.Module`): a model.
        input (:class:`torch.Tensor`): input tensor.
        target (int or :class:`torch.Tensor`): target label(s).
        saliency_layer (str or :class:`torch.nn.Module`): name of the saliency
            layer (str) or the layer itself (:class:`torch.nn.Module`) in the
            model at which to visualize.
        contrast_layer (str or :class:`torch.nn.Module`): name of the contrast
            layer (str) or the layer itself (:class:`torch.nn.Module`).
        classifier_layer (str or :class:`torch.nn.Module`, optional): name of
            the last classifier layer (str) or the layer itself
            (:class:`torch.nn.Module`). If ``None``, the function tries to
            automatically identify the last layer. Default: ``None``.
        resize (bool or tuple, optional): If True, resizes the saliency map to
            the same size as :attr:`input`. It is also possible to pass a
            (width, height) tuple to specify an arbitrary size. Default:
            ``False``.
        resize_mode (str, optional): Specify the resampling mode. Default:
            ``'bilinear'``.
        get_backward_gradient (function, optional): function that generates
            the gradient tensor to backpropagate. Default:
            :func:`.common.get_backward_gradient`.
        debug (bool, optional): If True, also return a
            :class:`collections.OrderedDict` of :class:`.common.Probe` objects
            attached to all named modules in the model. Default: ``False``.

    Returns:
        :class:`torch.Tensor` or tuple: If :attr:`debug` is False, returns a
        :class:`torch.Tensor` saliency map at :attr:`saliency_layer`.
        Otherwise, returns a tuple of a :class:`torch.Tensor` saliency map at
        :attr:`saliency_layer` and an :class:`collections.OrderedDict` of
        :class:`Probe` objects for all modules in the model.
    """
    debug_probes = attach_debug_probes(model, debug=debug)

    # Disable gradients for model parameters.
    for param in model.parameters():
        param.requires_grad_(False)

    # Set model to eval mode.
    if model.training:
        model.eval()

    saliency_layer = get_module(model, saliency_layer)
    contrast_layer = get_module(model, contrast_layer)

    if classifier_layer is None:
        classifier_layer, _ = _get_classifier_layer(model)
    classifier_layer = get_module(model, classifier_layer)

    with ExcitationBackpropContext():
        probe_contrast = Probe(contrast_layer, target='output')

        output = model(input)
        gradient = get_backward_gradient(output, target)

        try:
            # Flip the weights of the last layer.
            classifier_layer.weight.data.neg_()

            output.backward(gradient.clone())

            # Save the negated gradient and prepare to backpropagate the
            # contrastive gradient.
            probe_contrast.contrast = [probe_contrast.data[0].grad]
        finally:
            # Flip the weights back.
            classifier_layer.weight.data.neg_()

        # Forward-backward pass to get the positive gradient at the contrast
        # layer and backpropagate the contrastive gradient to the input.
        probe_saliency = Probe(saliency_layer, target='output')
        output = model(input)
        output.backward(gradient)

    saliency_map = gradient_to_contrastive_excitation_backprop_saliency(
        probe_saliency.data[0])

    probe_saliency.remove()
    probe_contrast.remove()

    saliency_map = resize_saliency(input, saliency_map, resize, resize_mode)

    if debug:
        return saliency_map, debug_probes
    else:
        return saliency_map