# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

r"""
This module implements the *deconvolution* method of [DECONV]_ for visualizing
deep networks. The simplest interface is given by the :func:`deconvnet`
function:

.. literalinclude:: ../examples/deconvnet.py
    :language: python
    :linenos:
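
If the example script is not at hand, the same usage can be sketched in a
few lines. The model and input below are placeholders (an untrained
torchvision network and a random tensor), not part of this module:

.. code-block:: python

    import torch
    import torchvision.models

    from torchray.attribution.deconvnet import deconvnet

    # Placeholder network and input; in practice, load trained weights
    # and use a properly preprocessed image batch.
    model = torchvision.models.resnet18().eval()
    x = torch.randn(1, 3, 224, 224)

    # Saliency map for class 243 of the (1000-class) classifier.
    saliency_map = deconvnet(model, x, 243)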

Alternatively, it is possible to run the method "manually". DeConvNet is a
backpropagation method, and thus works by changing the definition of the
backward functions of some layers. The modified ReLU is implemented by the
class :class:`DeConvNetReLU`; however, this class is rarely used directly.
Instead, one uses the :class:`DeConvNetContext` context, as follows:

.. literalinclude:: ../examples/deconvnet_manual.py
    :language: python
    :linenos:
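
Concretely, the manual pattern wraps the forward and backward passes in the
context, so that every ReLU uses the modified backward rule. A minimal
sketch, again with a placeholder model and input:

.. code-block:: python

    import torch
    import torchvision.models

    from torchray.attribution.deconvnet import DeConvNetContext

    model = torchvision.models.resnet18().eval()  # placeholder network
    x = torch.randn(1, 3, 224, 224).requires_grad_(True)

    with DeConvNetContext():
        y = model(x)   # forward pass with patched ReLUs
        z = y[0, 243]  # score of the target class
        z.backward()   # backward pass applies the DeConvNet rule

    saliency_map = x.grad  # gradient w.r.t. the input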

See also :ref:`Backpropagation methods <backpropagation>` for further examples
and discussion.

Theory
~~~~~~

The only change with respect to standard backpropagation is a modified
definition of the backward ReLU function:

.. math::
    \operatorname{ReLU}^*(x,p) =
    \begin{cases}
    p, & \mathrm{if}~ p > 0,\\
    0, & \mathrm{otherwise} \\
    \end{cases}

Warning:

    DeConvNets are defined for "standard" networks that use ReLU operations.
    Further modifications may be required for more complex or newer networks
    that use other types of non-linearities.

References:

    .. [DECONV] Zeiler and Fergus,
                *Visualizing and Understanding Convolutional Networks*,
                ECCV 2014,
                `<https://doi.org/10.1007/978-3-319-10590-1_53>`__.
"""

__all__ = ["DeConvNetContext", "deconvnet"]

import torch

from .common import ReLUContext, saliency


class DeConvNetReLU(torch.autograd.Function):
    """DeConvNet ReLU autograd function.

    This is an autograd function that redefines the ``relu`` function
    to match the DeConvNet ReLU definition.
    """

    @staticmethod
    def forward(ctx, input):
        """DeConvNet ReLU forward function."""
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """DeConvNet ReLU backward function."""
        return grad_output.clamp(min=0)


class DeConvNetContext(ReLUContext):
    """DeConvNet context.

    This context modifies the computation of the gradient to match the
    DeConvNet definition.

    See :mod:`torchray.attribution.deconvnet` for how to use it.
    """

    def __init__(self):
        super(DeConvNetContext, self).__init__(DeConvNetReLU)


def deconvnet(*args, context_builder=DeConvNetContext, **kwargs):
    """DeConvNet method.

    The function takes the same arguments as :func:`.common.saliency`, with
    the defaults required to apply the DeConvNet method, and returns the
    same values.
    """
    return saliency(*args, context_builder=context_builder, **kwargs)