Attribution

Attribution is the problem of determining which part of the input, e.g. an image, is responsible for the value computed by a predictor such as a neural network.

Formally, let \(\mathbf{x}\) be the input to a convolutional neural network, e.g., an \(N \times C \times H \times W\) real tensor. The neural network is a function \(\Phi\) mapping \(\mathbf{x}\) to a scalar output \(z \in \mathbb{R}\). The goal of attribution is then to find which elements of \(\mathbf{x}\) are “most responsible” for the outcome \(z\).

Some attribution methods are “black box” approaches, in the sense that they ignore the nature of the function \(\Phi\) (however, most assume that it is at least possible to compute the gradient of \(\Phi\) efficiently). Most attribution methods, however, are “white box” approaches, in the sense that they exploit the knowledge of the structure of \(\Phi\).

Backpropagation methods are “white box” visualization approaches that build on backpropagation, thus leveraging the functionality already implemented in standard deep learning toolboxes such as PyTorch.

Perturbation methods are “black box” visualization approaches that generate attribution visualizations by perturbing the input and observing the changes in a model’s output.

TorchRay implements the methods described in the sections below.

Footnotes

1

The gradient and extremal_perturbation methods actually straddle the boundary between white and black box methods, as they only require the ability to compute the gradient of the predictor, which does not necessarily require knowledge of the predictor's internals. However, in TorchRay both are implemented using backpropagation.

Backpropagation methods

Backpropagation methods work by tweaking the backpropagation algorithm that, on its own, computes the gradient of tensor functions. Formally, a neural network \(\Phi\) is a collection \(\Phi_1,\dots,\Phi_n\) of \(n\) layers. Each layer is in itself a “smaller” function inputting and outputting tensors, called activations (for simplicity, we call activations the network input and parameter tensors as well). Layers are interconnected in a Directed Acyclic Graph (DAG). The DAG is bipartite with some nodes representing the activation tensors and the other nodes representing the layers, with interconnections between layers and input/output tensors in the obvious way. The DAG sources are the network’s input and parameter tensors and the DAG sinks are the network’s output tensors.

The main goal of a deep neural network toolbox such as PyTorch is to evaluate the function \(\Phi\) implemented by the DAG as well as its gradients with respect to various tensors (usually the model parameters). The calculation of the gradients, which uses backpropagation, associates to the forward DAG a backward DAG, obtained as follows:

  • Activation tensors \(\mathbf{x}_j\) become gradient tensors \(d\mathbf{x}_j\) (preserving their shape).

  • Forward layers \(\Phi_i\) become backward layers \(\Phi_i^*\).

  • All arrows are reversed.

  • Additional arrows connecting the activation tensors \(\mathbf{x}_i\) as inputs to the corresponding backward function \(\Phi_i^*\) are added as well.

Backpropagation methods modify the backward graph in order to generate a visualization of the network forward pass. Additionally, inputs as well as intermediate activations can be inspected to obtain different visualizations. These two concepts are explained next.

Changing the backward propagation rules

Changing the backward propagation rules amounts to redefining the functions \(\Phi_i^*\). After doing so, the “gradients” computed by backpropagation change their meaning into something useful for visualization. We call these modified gradients pseudo-gradients.

TorchRay provides a number of context managers that enable patching PyTorch functions on the fly in order to change the backward propagation rules for a segment of code. For example, let x be an input tensor and model a deep classification network. Furthermore, let category_id be the index of the class for which we want to attribute input regions. The following code uses GuidedBackpropContext from the guided_backprop module to compute and store the pseudo-gradient in x.grad.

from torchray.attribution.guided_backprop import GuidedBackpropContext

x.requires_grad_(True)

with GuidedBackpropContext():
    y = model(x)
    z = y[0, category_id]
    z.backward()

At this point, x.grad contains the “guided gradient” computed by this method. This gradient is usually flattened along the channel dimension to produce a saliency map for visualization:

from torchray.attribution.common import gradient_to_saliency

saliency = gradient_to_saliency(x)

TorchRay also contains wrapper code, such as guided_backprop.guided_backprop(), that combines these steps in a way that works for common networks.

Probing intermediate activations and gradients

Most visualization methods are based on inspecting the activations computed when the network is evaluated and the pseudo-gradients computed during backpropagation. This is generally easy for input tensors. For intermediate tensors, when using PyTorch's functional interface, this is also easy: simply call retain_grad() to retain the gradient of an intermediate tensor:

from torch.nn.functional import relu, conv2d
from torchray.attribution.guided_backprop import GuidedBackpropContext

# x, weight, and class_index are assumed to be defined already.
x.requires_grad_(True)

with GuidedBackpropContext():
    y = conv2d(x, weight)
    y.retain_grad()  # retain the pseudo-gradient of the intermediate tensor
    z = relu(y)[0, class_index]
    z.backward()

# Now y and y.grad contain the activation and guided gradient,
# respectively.

However, in PyTorch most network components are implemented as torch.nn.Module objects. In this case, it is not obvious how to access a specific layer's information. In order to simplify this process, the library provides the Probe class:

from torchray.attribution.guided_backprop import GuidedBackpropContext
from torchray.attribution.common import Probe

# alexnet, x, and class_index are assumed to be defined already.
# Attach a probe to the last conv layer.
probe = Probe(alexnet.features[11])

with GuidedBackpropContext():
    y = alexnet(x)
    z = y[0, class_index]
    z.backward()

# Now probe.data[0] and probe.data[0].grad contain
# the activations and guided gradients.

The probe automatically applies torch.Tensor.requires_grad_() and torch.Tensor.retain_grad() as needed. You can use probe.remove() to remove the probe from the network once you are done.

Limitations

Except for the gradient method, backpropagation methods require modifying the backward function of each layer. TorchRay implements the rules necessary to do so as originally defined by the authors of each method. However, as new neural network layers are introduced, the default behavior, which is to leave backpropagation unchanged, may be inappropriate or suboptimal for them.

Perturbation methods

Perturbation methods work by changing the input to the neural network in a controlled manner and observing the effect on the network's output. Attribution can be achieved by occluding (setting to zero) specific parts of the image and checking whether this has a strong effect on the output. This can be thought of as a form of sensitivity analysis which is still specific to a given input but, unlike the gradient method, is not differential.
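
For instance, the following self-contained sketch (not part of TorchRay; the function name and its parameters are illustrative) applies the basic occlusion idea: slide a zero-valued patch over the image and record how much the target class score drops.

import torch

def occlusion_sensitivity(model, x, category_id, patch=32, stride=16):
    """Illustrative occlusion analysis: larger values mark regions whose
    occlusion reduces the target score the most."""
    model.eval()
    _, _, height, width = x.shape
    rows = (height - patch) // stride + 1
    cols = (width - patch) // stride + 1
    heatmap = torch.zeros(rows, cols)
    with torch.no_grad():
        base_score = model(x)[0, category_id].item()
        for i in range(rows):
            for j in range(cols):
                x_occluded = x.clone()
                x_occluded[:, :, i * stride:i * stride + patch,
                           j * stride:j * stride + patch] = 0
                heatmap[i, j] = base_score - model(x_occluded)[0, category_id].item()
    return heatmap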

DeConvNet

This module implements the deconvolution method of [DECONV] for visualizing deep networks. The simplest interface is given by the deconvnet() function:

from torchray.attribution.deconvnet import deconvnet
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# DeConvNet method.
saliency = deconvnet(model, x, category_id)

# Plots.
plot_example(x, saliency, 'deconvnet', category_id)

Alternatively, it is possible to run the method “manually”. DeConvNet is a backpropagation method, and thus works by changing the definition of the backward functions of some layers. The modified ReLU is implemented by the class DeConvNetReLU; however, this is rarely used directly; instead, one uses the DeConvNetContext context manager, as follows:

from torchray.attribution.common import gradient_to_saliency
from torchray.attribution.deconvnet import DeConvNetContext
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# DeConvNet method.
x.requires_grad_(True)

with DeConvNetContext():
    y = model(x)
    z = y[0, category_id]
    z.backward()

saliency = gradient_to_saliency(x)

# Plots.
plot_example(x, saliency, 'deconvnet', category_id)

See also Backpropagation methods for further examples and discussion.

Theory

The only change is a modified definition of the backward ReLU function:

\[\begin{split}\operatorname{ReLU}^*(x,p) = \begin{cases} p, & \mathrm{if}~ p > 0,\\ 0, & \mathrm{otherwise} \\ \end{cases}\end{split}\]
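
In PyTorch terms, this rule can be sketched as a custom autograd function (the class name below is illustrative; TorchRay's actual implementation is the DeConvNetReLU class):

import torch

class DeConvNetReLUSketch(torch.autograd.Function):
    """Sketch of the DeConvNet backward ReLU rule (illustrative name)."""

    @staticmethod
    def forward(ctx, x):
        # The forward pass is the ordinary ReLU.
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, p):
        # Pass the pseudo-gradient through wherever *it* is positive,
        # regardless of the sign of the forward input.
        return p.clamp(min=0)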

Warning

DeConvNets are defined for “standard” networks that use ReLU operations. Further modifications may be required for more complex or new networks that use other types of non-linearities.

References

DECONV

Zeiler and Fergus, Visualizing and Understanding Convolutional Networks, ECCV 2014, https://doi.org/10.1007/978-3-319-10590-1_53.

class torchray.attribution.deconvnet.DeConvNetContext[source]

Bases: torchray.attribution.common.ReLUContext

DeConvNet context.

This context modifies the computation of gradient to match the DeConvNet definition.

See torchray.attribution.deconvnet for how to use it.

torchray.attribution.deconvnet.deconvnet(*args, context_builder=<class 'torchray.attribution.deconvnet.DeConvNetContext'>, **kwargs)[source]

DeConvNet method.

The function takes the same arguments as common.saliency(), with the defaults required to apply the DeConvNet method, and supports the same arguments and return values.

Excitation backprop

This module provides an implementation of the excitation backpropagation method of [EBP] for saliency visualization. It is a backpropagation method, and thus works by changing the definition of the backward functions of some layers.

In simple cases, the excitation_backprop() function can be used to obtain the required visualization, as in the following example:

from torchray.attribution.excitation_backprop import excitation_backprop
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Excitation backprop.
saliency = excitation_backprop(
    model,
    x,
    category_id,
    saliency_layer='features.9',
)

# Plots.
plot_example(x, saliency, 'excitation backprop', category_id)

Alternatively, you can explicitly use the ExcitationBackpropContext, as follows:

from torchray.attribution.common import Probe, get_module
from torchray.attribution.excitation_backprop import ExcitationBackpropContext
from torchray.attribution.excitation_backprop import gradient_to_excitation_backprop_saliency
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Excitation backprop.
saliency_layer = get_module(model, 'features.9')
saliency_probe = Probe(saliency_layer, target='output')

with ExcitationBackpropContext():
    y = model(x)
    z = y[0, category_id]
    z.backward()

saliency = gradient_to_excitation_backprop_saliency(saliency_probe.data[0])

saliency_probe.remove()

# Plots.
plot_example(x, saliency, 'excitation backprop', category_id)

See also Attribution for further examples and discussion.

Contrastive variant

The contrastive variant of excitation backprop passes the data through the network twice. The first pass is used to obtain “contrast” activations at some intermediate layer contrast_layer; these are obtained by flipping the sign of the weights of the last classification layer classifier_layer. The visualization is then obtained at some earlier layer. The function contrastive_excitation_backprop() can be used to compute this saliency:

from torchray.attribution.excitation_backprop import contrastive_excitation_backprop
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Contrastive excitation backprop.
saliency = contrastive_excitation_backprop(
    model,
    x,
    category_id,
    saliency_layer='features.9',
    contrast_layer='features.30',
    classifier_layer='classifier.6',
)

# Plots.
plot_example(x, saliency, 'contrastive excitation backprop', category_id)

This can also be done “manually”, as follows:

from torchray.attribution.common import Probe, get_module
from torchray.attribution.excitation_backprop import ExcitationBackpropContext
from torchray.attribution.excitation_backprop import gradient_to_contrastive_excitation_backprop_saliency
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Contrastive excitation backprop.
input_layer = get_module(model, 'features.9')
contrast_layer = get_module(model, 'features.30')
classifier_layer = get_module(model, 'classifier.6')

input_probe = Probe(input_layer, target='output')
contrast_probe = Probe(contrast_layer, target='output')

with ExcitationBackpropContext():
    y = model(x)
    z = y[0, category_id]
    classifier_layer.weight.data.neg_()
    z.backward()

    classifier_layer.weight.data.neg_()

    contrast_probe.contrast = [contrast_probe.data[0].grad]

    y = model(x)
    z = y[0, category_id]
    z.backward()

saliency = gradient_to_contrastive_excitation_backprop_saliency(input_probe.data[0])

input_probe.remove()
contrast_probe.remove()

# Plots.
plot_example(x, saliency, 'contrastive excitation backprop', category_id)

Theory

Excitation backprop modifies the backward version of all linear layers in the network. For a simple 1D case, let the forward layer be given by:

\[y_i = \sum_{j=1}^N w_{ij} x_j\]

where \(x \in \mathbb{R}^N\) is the input and \(w \in \mathbb{R}^{M \times N}\) the weight matrix. On the way back, if \(p \in \mathbb{R}^{M}\) is the output pseudo-gradient, the input pseudo-gradient \(p' \in \mathbb{R}^{N}\) is given by:

(1)\[p'_j = \sum_{i=1}^M \frac{w^+_{ij} x_j}{\sum_{k=1}^N w^+_{ik} x_k}\, p_i \quad\text{where}\quad w^+_{ij} = \max\{0, w_{ij}\}\]

Note that [EBP] assumes that the input activations \(x\) are always non-negative. This is often true, as linear layers are preceded by non-negative activation functions such as ReLU, so there is nothing to be done.

Note also that we can rearrange (1) as

\[p'_j = x_j \sum_{i=1}^M w^+_{ij} \hat p_i, \quad\text{where}\quad \hat p_i = \frac{p_i}{\sum_{k=1}^N w^+_{ik} x_k}.\]

Here \(\hat p\) is a normalized version of the output pseudo-gradient and the summation is the operation performed by the standard backward function for the linear layer given as input \(\hat p\).

All linear layers, including convolution and deconvolution layers as well as average pooling layers, can be processed in this manner. In general, let \(y = f(x,w)\) be a linear layer. In order to compute (1), we can expand it as:

\[p' = x \odot f^*(x, w^+, p \oslash f(x, w^+))\]

where \(f^*(x, w, p)\) is the standard backward function (vector-Jacobian product) for the linear layer \(f\) and \(\odot\) and \(\oslash\) denote element-wise multiplication and division, respectively.
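
For a fully-connected layer, this expansion can be written out directly. The following sketch is illustrative only (the helper name and the small epsilon added for numerical stability are assumptions, not part of TorchRay):

import torch
import torch.nn.functional as F

def excitation_backprop_linear(x, weight, p, eps=1e-12):
    """Sketch of rule (1) for a linear layer with input x (B x N),
    weight w (M x N), and output pseudo-gradient p (B x M)."""
    w_plus = weight.clamp(min=0)     # w^+
    norm = F.linear(x, w_plus)       # f(x, w^+), shape B x M
    p_hat = p / (norm + eps)         # p divided elementwise by f(x, w^+)
    back = p_hat.matmul(w_plus)      # f^*(x, w^+, p_hat), shape B x N
    return x * back                  # x multiplied elementwise by the backward signal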

The contrastive variant of excitation backprop is similar, but uses the idea of contrasting the excitation for one class with the ones of all the others.

In order to obtain the “contrast” signal, excitation backprop is run as before except for the last classification linear layer. In order to backpropagate the excitation for this layer only, the weight parameter \(w\) is replaced with its opposite \(-w\). Then, the excitations are backpropagated to an intermediate contrast linear layer as normal.

Once the contrast activations have been obtained, excitation backprop is run again, this time with the “default” weights even for the last linear layer. However, during backpropagation, when the contrast layer is reached again, the contrast is subtracted from the excitations. Then, the excitations are propagated backward as usual.

References

EBP

Jianming Zhang, Zhe Lin, Jonathan Brandt, Xiaohui Shen, Stan Sclaroff, Top-down Neural Attention by Excitation Backprop, ECCV 2016, https://arxiv.org/abs/1608.00507.

torchray.attribution.excitation_backprop.contrastive_excitation_backprop(model, input, target, saliency_layer, contrast_layer, classifier_layer=None, resize=False, resize_mode='bilinear', get_backward_gradient=<function get_backward_gradient>, debug=False)[source]

Contrastive excitation backprop.

Parameters
  • model (torch.nn.Module) – a model.

  • input (torch.Tensor) – input tensor.

  • target (int or torch.Tensor) – target label(s).

  • saliency_layer (str or torch.nn.Module) – name of the saliency layer (str) or the layer itself (torch.nn.Module) in the model at which to visualize.

  • contrast_layer (str or torch.nn.Module) – name of the contrast layer (str) or the layer itself (torch.nn.Module).

  • classifier_layer (str or torch.nn.Module, optional) – name of the last classifier layer (str) or the layer itself (torch.nn.Module). If None, the function tries to automatically identify the last classifier layer. Default: None.

  • resize (bool or tuple, optional) – If True resizes the saliency map to the same size as input. It is also possible to pass a (width, height) tuple to specify an arbitrary size. Default: False.

  • resize_mode (str, optional) – Specify the resampling mode. Default: 'bilinear'.

  • get_backward_gradient (function, optional) – function that generates gradient tensor to backpropagate. Default: common.get_backward_gradient().

  • debug (bool, optional) – If True, also return a collections.OrderedDict of common.Probe objects attached to all named modules in the model. Default: False.

Returns

If debug is False, returns a torch.Tensor saliency map at saliency_layer. Otherwise, returns a tuple of a torch.Tensor saliency map at saliency_layer and a collections.OrderedDict of Probe objects for all modules in the model.

Return type

torch.Tensor or tuple

torchray.attribution.excitation_backprop.excitation_backprop(*args, context_builder=<class 'torchray.attribution.excitation_backprop.ExcitationBackpropContext'>, gradient_to_saliency=<function gradient_to_excitation_backprop_saliency>, **kwargs)[source]

Excitation backprop.

The function takes the same arguments as common.saliency(), with the defaults required to apply the Excitation backprop method, and supports the same arguments and return values.

class torchray.attribution.excitation_backprop.ExcitationBackpropContext(enable=True, debug=False)[source]

Bases: object

Context to use Excitation Backpropagation rules.

Extremal perturbation

This module provides an implementation of the Extremal Perturbations (EP) method of [EP] for saliency visualization. The interface is given by the extremal_perturbation() function:

from torchray.attribution.extremal_perturbation import extremal_perturbation, contrastive_reward
from torchray.benchmark import get_example_data, plot_example
from torchray.utils import get_device

# Obtain example data.
model, x, category_id_1, category_id_2 = get_example_data()

# Run on GPU if available.
device = get_device()
model.to(device)
x = x.to(device)

# Extremal perturbation backprop.
masks_1, _ = extremal_perturbation(
    model, x, category_id_1,
    reward_func=contrastive_reward,
    debug=True,
    areas=[0.12],
)

masks_2, _ = extremal_perturbation(
    model, x, category_id_2,
    reward_func=contrastive_reward,
    debug=True,
    areas=[0.05],
)

# Plots.
plot_example(x, masks_1, 'extremal perturbation', category_id_1)
plot_example(x, masks_2, 'extremal perturbation', category_id_2)

Extremal perturbations seek to find a region of the input image that maximally excites a certain output or intermediate activation of a neural network.

Perturbation types

The Perturbation class supports two perturbation types, blur (BLUR_PERTURBATION) and fade (FADE_PERTURBATION), documented below.

Extremal perturbation variants

The extremal_perturbation() function supports the following variants:

  • PRESERVE_VARIANT: Find a mask that makes the activations large.

  • DELETE_VARIANT: Find a mask that makes the activations small.

  • DUAL_VARIANT: Find a mask that makes the activations large and whose complement makes the activations small, rewarding the difference between these two.

References

EP

Ruth C. Fong, Mandela Patrick and Andrea Vedaldi, Understanding Deep Networks via Extremal Perturbations and Smooth Masks, ICCV 2019, https://arxiv.org/abs/1910.08485.

torchray.attribution.extremal_perturbation.extremal_perturbation(model, input, target, areas=[0.1], perturbation='blur', max_iter=800, num_levels=8, step=7, sigma=21, jitter=True, variant='preserve', print_iter=None, debug=False, reward_func=<function simple_reward>, resize=False, resize_mode='bilinear', smooth=0)[source]

Compute a set of extremal perturbations.

The function takes a model, an input tensor \(x\) of size \(1\times C\times H\times W\), and a target activation channel. It produces as output a \(K\times C\times H\times W\) tensor where \(K\) is the number of specified areas.

Each mask, which has approximately the specified area, is searched in order to maximise the (spatial average of the) activations in channel target. Alternative objectives can be specified via reward_func.

Parameters
  • model (torch.nn.Module) – model.

  • input (torch.Tensor) – input tensor.

  • target (int) – target channel.

  • areas (float or list of floats, optional) – list of target areas for saliency masks. Defaults to [0.1].

  • perturbation (str, optional) – Perturbation types.

  • max_iter (int, optional) – number of iterations for optimizing the masks.

  • num_levels (int, optional) – number of buckets with which to discretize and linearly interpolate the perturbation (see Perturbation). Defaults to 8.

  • step (int, optional) – mask step (see MaskGenerator). Defaults to 7.

  • sigma (float, optional) – mask smoothing (see MaskGenerator). Defaults to 21.

  • jitter (bool, optional) – randomly flip the image horizontally at each iteration. Defaults to True.

  • variant (str, optional) – Extremal perturbation variants. Defaults to PRESERVE_VARIANT.

  • print_iter (int, optional) – frequency with which to print losses. Defaults to None.

  • debug (bool, optional) – If True, generate debug plots.

  • reward_func (function, optional) – function that generates reward tensor to backpropagate.

  • resize (bool, optional) – If True, upsamples the masks to the same size as input. It is also possible to specify a pair (width, height) for a different size. Defaults to False.

  • resize_mode (str, optional) – Upsampling method to use. Defaults to 'bilinear'.

  • smooth (float, optional) – Apply Gaussian smoothing to the masks after computing them. Defaults to 0.

Returns

A tuple containing the masks and the energies. The masks are stored as a torch.Tensor of dimension

class torchray.attribution.extremal_perturbation.Perturbation(input, num_levels=8, max_blur=20, type='blur')[source]

Bases: object

Perturbation pyramid.

The class takes as input a tensor input and applies to it perturbations of increasing strength, storing the resulting pyramid as the class state. The method apply() can then be used to generate an inhomogeneously perturbed image based on a certain perturbation mask.

The pyramid \(y\) is the \(L\times C\times H\times W\) tensor

\[y_{lcvu} = [\operatorname{perturb}(x, \sigma_l)]_{cvu}\]

where \(x\) is the input tensor, \(c\) a channel, \(vu\) the spatial location, \(l\) a perturbation level, and \(\operatorname{perturb}\) is a perturbation operator.

For the blur perturbation (BLUR_PERTURBATION), the perturbation operator amounts to convolution with a Gaussian whose kernel has standard deviation \(\sigma_l = \sigma_{\mathrm{max}} (1 - l/ (L-1))\):

\[\operatorname{perturb}(x, \sigma_l) = g_{\sigma_l} \ast x\]

For the fade perturbation (FADE_PERTURBATION),

\[\operatorname{perturb}(x, \sigma_l) = \sigma_l \cdot x\]

where \(\sigma_l = l / (L-1)\).

Note that in all cases the last pyramid level \(l=L-1\) corresponds to the unperturbed input and the first \(l=0\) to the maximally perturbed input.
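
As a rough illustration of the blur pyramid described above, the following sketch builds the \(L\) levels explicitly. It assumes a 1 x C x H x W input and uses torchvision's gaussian_blur for convenience; TorchRay's own implementation is this Perturbation class.

import torch
from torchvision.transforms.functional import gaussian_blur

def blur_pyramid(x, num_levels=8, max_blur=20):
    """Stack of increasingly blurred copies of x, from maximally blurred
    (level 0) to the unperturbed input (last level)."""
    levels = []
    for level in range(num_levels):
        sigma = max_blur * (1 - level / (num_levels - 1))
        if sigma > 0:
            kernel_size = int(2 * round(3 * sigma) + 1)  # cover about +/- 3 sigma
            levels.append(gaussian_blur(x, [kernel_size, kernel_size], [sigma, sigma]))
        else:
            levels.append(x)  # the last level is the unperturbed input
    return torch.cat(levels, dim=0)  # L x C x H x W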

Parameters
  • input (torch.Tensor) – A \(1\times C\times H\times W\) input tensor (usually an image).

  • num_levels (int, optional) – Number of pyramid levels. Defaults to 8.

  • type (str, optional) – Perturbation type (Perturbation types).

  • max_blur (float, optional) – \(\sigma_{\mathrm{max}}\) for the Gaussian blur perturbation. Defaults to 20.

pyramid

An \(L\times C\times H\times W\) tensor with \(L\) (num_levels) increasingly perturbed versions of the input tensor.

Type

torch.Tensor

apply(mask)[source]

Generate a perturbed tensor from a perturbation mask.

The mask is a tensor \(K\times 1\times H\times W\) with spatial dimensions \(H\times W\) matching those of the input tensor passed upon instantiation of the class. The output is a \(K\times C\times H\times W\) tensor with \(K\) perturbed versions of the input tensor, one for each mask.

Mask values are in the range 0 to 1, where 1 means that the input tensor is copied as is, and 0 that it is maximally perturbed.

Formally, the output is then given by:

\[z_{kcvu} = y_{m_{k1vu}, c, v, u}\]

where \(k\) indexes the mask, \(c\) the feature channel, \(vu\) the spatial location, \(y\) is the pyramid tensor, and \(m\) is the mask tensor mask.

The mask must be in the range \([0, 1]\). Linear interpolation is used to index the perturbation level dimension of \(y\).

Parameters

mask (torch.Tensor) – A \(K\times 1\times H\times W\) input tensor representing \(K\) masks.

Returns

A \(K\times C\times H\times W\) tensor with \(K\) perturbed versions of the input tensor.

Return type

torch.Tensor

to(dev)[source]

Switch to another device.

Parameters

dev – PyTorch device.

Returns

self.

Return type

Perturbation

torchray.attribution.extremal_perturbation.simple_reward(activation, target, variant)[source]

Simple reward.

For the PRESERVE_VARIANT, the simple reward is given by:

\[z_{k1vu} = y_{k, c, v, u}\]

where \(y\) is the \(K\times C\times H\times W\) activation tensor, \(c\) the target channel, \(k\) the mask index and \(vu\) the spatial indices. \(c\) must be in the range \([0, C-1]\).

For the DELETE_VARIANT, the reward is the opposite.

For the DUAL_VARIANT, it is given by:

\[z_{n1vu} = y_{n, c, v, u} - y_{n + N/2, c, v, u}.\]
Parameters
Returns

reward tensor with the same shape as activation but a single channel.

Return type

torch.Tensor

torchray.attribution.extremal_perturbation.contrastive_reward(activation, target, variant)[source]

Contrastive reward.

For the PRESERVE_VARIANT, the contrastive reward is given by:

\[z_{k1vu} = y_{k, c, v, u} - \max_{c'\not= c} y_{k, c', v, u}\]

The other variants are derived in the same manner as for simple_reward().

Parameters
Returns

reward tensor with the same shape as activation but a single channel.

Return type

torch.Tensor

torchray.attribution.extremal_perturbation.BLUR_PERTURBATION = 'blur'

Blur-type perturbation for Perturbation.

torchray.attribution.extremal_perturbation.FADE_PERTURBATION = 'fade'

Fade-type perturbation for Perturbation.

torchray.attribution.extremal_perturbation.PRESERVE_VARIANT = 'preserve'

Preservation game for extremal_perturbation().

torchray.attribution.extremal_perturbation.DELETE_VARIANT = 'delete'

Deletion game for extremal_perturbation().

torchray.attribution.extremal_perturbation.DUAL_VARIANT = 'dual'

Combined game for extremal_perturbation().

Gradient

This module implements the gradient method of [GRAD] for visualizing a deep network. It is a backpropagation method, and in fact the simplest of them all as it coincides with standard backpropagation. The simplest way to use this method is via the gradient() function:

from torchray.attribution.gradient import gradient
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Gradient method.
saliency = gradient(model, x, category_id)

# Plots.
plot_example(x, saliency, 'gradient', category_id)

Alternatively, one can do so manually, as follows

from torchray.attribution.common import gradient_to_saliency
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Gradient method.

x.requires_grad_(True)
y = model(x)
z = y[0, category_id]
z.backward()

saliency = gradient_to_saliency(x)

# Plots.
plot_example(x, saliency, 'gradient', category_id)

Note that in this example, for visualization, the gradient is converted into a saliency image by post-processing with the function torchray.attribution.common.gradient_to_saliency().

See also Attribution for further examples.

References

GRAD

Karen Simonyan, Andrea Vedaldi and Andrew Zisserman, Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps, ICLR workshop, 2014, https://arxiv.org/abs/1312.6034.

torchray.attribution.gradient.gradient(*args, context_builder=None, **kwargs)[source]

Gradient method

The function takes the same arguments as common.saliency(), with the defaults required to apply the gradient method, and supports the same arguments and return values.

Grad-CAM

This module provides an implementation of the Grad-CAM method of [GRADCAM] for saliency visualization. The simplest interface is given by the grad_cam() function:

from torchray.attribution.grad_cam import grad_cam
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Grad-CAM backprop.
saliency = grad_cam(model, x, category_id, saliency_layer='features.29')

# Plots.
plot_example(x, saliency, 'grad-cam backprop', category_id)

Alternatively, it is possible to run the method “manually”. Grad-CAM backprop is a variant of the gradient method, applied at an intermediate layer:

from torchray.attribution.common import Probe, get_module
from torchray.attribution.grad_cam import gradient_to_grad_cam_saliency
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Grad-CAM backprop.
saliency_layer = get_module(model, 'features.29')

probe = Probe(saliency_layer, target='output')

y = model(x)
z = y[0, category_id]
z.backward()

saliency = gradient_to_grad_cam_saliency(probe.data[0])

# Plots.
plot_example(x, saliency, 'grad-cam backprop', category_id)

Note that the function gradient_to_grad_cam_saliency() is used to convert activations and gradients to a saliency map.

See also Attribution for further examples and discussion.

Theory

Grad-CAM can be seen as a variant of the gradient method (torchray.attribution.gradient) with two differences:

  1. The saliency is measured at an intermediate layer of the network, usually at the output of the last convolutional layer.

  2. Saliency is defined as the clamped product of forward activation and backward gradient at that layer.
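
A sketch of this computation in the standard Grad-CAM form is shown below (the helper name is illustrative; cf. TorchRay's gradient_to_grad_cam_saliency(), which may differ in minor details):

import torch

def grad_cam_saliency_sketch(activation):
    """Standard Grad-CAM combination, assuming `activation` is an
    N x C x H x W tensor whose .grad was filled in by a backward pass."""
    # Channel weights: spatial average of the backward gradient.
    weights = activation.grad.mean(dim=(2, 3), keepdim=True)  # N x C x 1 x 1
    # Weighted sum over channels, clamped at zero.
    return torch.clamp((weights * activation).sum(dim=1, keepdim=True), min=0)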

References

GRADCAM

Ramprasaath R. Selvaraju, Abhishek Das, Ramakrishna Vedantam, Michael Cogswell, Devi Parikh and Dhruv Batra, Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, ICCV 2017, http://openaccess.thecvf.com/content_iccv_2017/html/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.html.

torchray.attribution.grad_cam.grad_cam(*args, saliency_layer, gradient_to_saliency=<function gradient_to_grad_cam_saliency>, **kwargs)[source]

Grad-CAM method.

The function takes the same arguments as common.saliency(), with the defaults required to apply the Grad-CAM method, and supports the same arguments and return values.

Guided backprop

This module implements the guided backpropagation method of [GUIDED] for visualizing deep networks. The simplest interface is given by the guided_backprop() function:

from torchray.attribution.guided_backprop import guided_backprop
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Guided backprop.
saliency = guided_backprop(model, x, category_id)

# Plots.
plot_example(x, saliency, 'guided backprop', category_id)

Alternatively, it is possible to run the method “manually”. Guided backprop is a backpropagation method, and thus works by changing the definition of the backward functions of some layers. This can be done using the GuidedBackpropContext context:

from torchray.attribution.common import gradient_to_saliency
from torchray.attribution.guided_backprop import GuidedBackpropContext
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Guided backprop.
x.requires_grad_(True)

with GuidedBackpropContext():
    y = model(x)
    z = y[0, category_id]
    z.backward()

saliency = gradient_to_saliency(x)

# Plots.
plot_example(x, saliency, 'guided backprop', category_id)

See also Attribution for further examples.

Theory

Guided backprop is a backpropagation method, and thus it works by changing the definition of the backward functions of some layers. The only change is a modified definition of the backward ReLU function:

\[\begin{split}\operatorname{ReLU}^*(x,p) = \begin{cases} p, & \mathrm{if}~p > 0 ~\mathrm{and}~ x > 0,\\ 0, & \mathrm{otherwise} \\ \end{cases}\end{split}\]

The modified ReLU is implemented by class GuidedBackpropReLU.
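
A minimal sketch of this rule as a custom autograd function (the class name is illustrative; TorchRay's own implementation is GuidedBackpropReLU):

import torch

class GuidedReLUSketch(torch.autograd.Function):
    """Sketch of the guided backprop ReLU rule (illustrative name)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, p):
        x, = ctx.saved_tensors
        # Pass the pseudo-gradient only where both the forward input
        # and the pseudo-gradient are positive.
        return p.clamp(min=0) * (x > 0).to(p.dtype)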

References

GUIDED

Springenberg et al., Striving for simplicity: The all convolutional net, ICLR Workshop 2015, https://arxiv.org/abs/1412.6806.

class torchray.attribution.guided_backprop.GuidedBackpropContext[source]

Bases: torchray.attribution.common.ReLUContext

GuidedBackprop context.

This context modifies the computation of gradients to match the guided backpropagation definition.

See torchray.attribution.guided_backprop for how to use it.

torchray.attribution.guided_backprop.guided_backprop(*args, context_builder=<class 'torchray.attribution.guided_backprop.GuidedBackpropContext'>, **kwargs)[source]

Guided backprop.

The function takes the same arguments as common.saliency(), with the defaults required to apply the guided backprop method, and supports the same arguments and return values.

Linear approximation

This module provides an implementation of the linear approximation method for saliency visualization. The simplest interface is given by the linear_approx() function:

from torchray.attribution.linear_approx import linear_approx
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Linear approximation backprop.
saliency = linear_approx(model, x, category_id, saliency_layer='features.29')

# Plots.
plot_example(x, saliency, 'linear approx', category_id)

Alternatively, it is possible to run the method “manually”. Linear approximation is a variant of the gradient method, applied at an intermediate layer:

from torchray.attribution.common import Probe, get_module
from torchray.attribution.linear_approx import gradient_to_linear_approx_saliency
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Linear approximation.
saliency_layer = get_module(model, 'features.29')

probe = Probe(saliency_layer, target='output')

y = model(x)
z = y[0, category_id]
z.backward()

saliency = gradient_to_linear_approx_saliency(probe.data[0])

# Plots.
plot_example(x, saliency, 'linear approx', category_id)

Note that the function gradient_to_linear_approx_saliency() is used to convert activations and gradients to a saliency map.
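
The underlying computation is a channel-wise sum of activation times gradient; a minimal sketch (assuming x holds the layer activations with a populated x.grad) is:

import torch

def linear_approx_saliency_sketch(x):
    """Sketch: channel-wise sum of the activation times its gradient
    (cf. gradient_to_linear_approx_saliency)."""
    return (x * x.grad).sum(dim=1, keepdim=True)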

torchray.attribution.linear_approx.gradient_to_linear_approx_saliency(x)[source]

Returns the linear approximation of a tensor.

The tensor x must have a valid gradient x.grad. The function then computes the saliency map \(s\) given by:

\[s_{n1u} = \sum_{c} x_{ncu} \cdot dx_{ncu}\]
Parameters

x (torch.Tensor) – activation tensor with a valid gradient.

Returns

Saliency map.

Return type

torch.Tensor

torchray.attribution.linear_approx.linear_approx(*args, gradient_to_saliency=<function gradient_to_linear_approx_saliency>, **kwargs)[source]

Linear approximation.

The function takes the same arguments as common.saliency(), with the defaults required to apply the linear approximation method, and supports the same arguments and return values.

RISE

This module provides an implementation of the RISE method of [RISE] for saliency visualization. This is given by the rise() function, which can be used as follows:

from torchray.attribution.rise import rise
from torchray.benchmark import get_example_data, plot_example
from torchray.utils import get_device

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Run on GPU if available.
device = get_device()
model.to(device)
x = x.to(device)

# RISE method.
saliency = rise(model, x)
saliency = saliency[:, category_id].unsqueeze(0)

# Plots.
plot_example(x, saliency, 'RISE', category_id)

References

RISE

V. Petsiuk, A. Das and K. Saenko, RISE: Randomized Input Sampling for Explanation of Black-box Models, BMVC 2018, https://arxiv.org/pdf/1806.07421.pdf.

torchray.attribution.rise.rise(model, input, target=None, seed=0, num_masks=8000, num_cells=7, filter_masks=None, batch_size=32, p=0.5, resize=False, resize_mode='bilinear')[source]

RISE.

Parameters
  • model (torch.nn.Module) – a model.

  • input (torch.Tensor) – input tensor.

  • seed (int, optional) – manual seed used to generate random numbers. Default: 0.

  • num_masks (int, optional) – number of RISE random masks to use. Default: 8000.

  • num_cells (int, optional) – number of cells for one spatial dimension in low-res RISE random mask. Default: 7.

  • filter_masks (torch.Tensor, optional) – If given, use the provided pre-computed filter masks. Default: None.

  • batch_size (int, optional) – batch size to use. Default: 32.

  • p (float, optional) – with prob p, a low-res cell is set to 0; otherwise, it’s 1. Default: 0.5.

  • resize (bool or tuple of ints, optional) – If True, resize saliency map to size of input. If False, don’t resize. If (width, height) tuple, resize to (width, height). Default: False.

  • resize_mode (str, optional) – If resize is not None, use this mode for the resize function. Default: 'bilinear'.

Returns

RISE saliency map.

Return type

torch.Tensor

torchray.attribution.rise.rise_class(*args, target, **kwargs)[source]

Class-specific RISE.

This function has all the arguments of rise(), with the following additional argument, and returns a class-specific saliency map for the given target class(es).

Parameters

target (int, torch.Tensor, list, or np.ndarray) – target label(s) that can be cast to torch.long.

Common code

This module defines common code for the backpropagation methods.

torchray.attribution.common.attach_debug_probes(model, debug=False)[source]

Returns a collections.OrderedDict of Probe objects for all modules in the model if debug is True; otherwise, returns None.

Parameters
  • model (torch.nn.Module) – a model.

  • debug (bool, optional) – if True, return an OrderedDict of Probe objects for all modules in the model; otherwise returns None. Default: False.

Returns

OrderedDict of Probe objects for all modules in the model (None if debug is False).

Return type

collections.OrderedDict

torchray.attribution.common.get_backward_gradient(pred_y, y)[source]

Returns a gradient tensor that is either equal to y (if y is a tensor with the same shape as pred_y) or a one-hot encoding in the channels dimension.

y can be either an int, an array-like list of integers, or a tensor. If y is a tensor with the same shape as pred_y, the function returns y unchanged.

Otherwise, y is interpreted as a list of class indices. These are first unfolded/expanded to one index per batch element in pred_y (i.e. along the first dimension). Then, this list is further expanded to all spatial dimensions of pred_y. (i.e. all but the first two dimensions of pred_y). Finally, the function return a “gradient” tensor that is a one-hot indicator tensor for these classes.
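
For the common case of a single class index, the one-hot construction amounts to the following sketch (not TorchRay's exact code):

import torch

def one_hot_backward_gradient(pred_y, class_index):
    """Sketch: one-hot gradient selecting class_index along the channel
    dimension, expanded over the batch and any spatial dimensions."""
    grad = torch.zeros_like(pred_y)
    grad[:, class_index, ...] = 1.0
    return grad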

Parameters
  • pred_y (torch.Tensor) – model output tensor.

  • y (int, torch.Tensor, list, or np.ndarray) – target label(s) that can be cast to torch.long.

Returns

gradient tensor with the same shape as pred_y.

Return type

torch.Tensor

torchray.attribution.common.get_module(model, module)[source]

Returns a specific layer in a model, identified by name or by the module instance itself.

module is either the name of a module (as given by the named_modules() function for torch.nn.Module objects) or a torch.nn.Module object. If module is a torch.nn.Module object, then module is returned unchanged. If module is a str, the function searches for a module with the name module and returns a torch.nn.Module if found; otherwise, None is returned.

Parameters
  • model (torch.nn.Module) – model in which to search for layer.

  • module (str or torch.nn.Module) – name of layer (str) or the layer itself (torch.nn.Module).

Returns

specific PyTorch layer (None if the layer isn't found).

Return type

torch.nn.Module

torchray.attribution.common.get_pointing_gradient(pred_y, y, normalize=True)[source]

Returns a gradient tensor for the pointing game.

Parameters
  • pred_y (torch.Tensor) – 4D tensor that the model outputs.

  • y (int) – target label.

  • normalize (bool) – If True, normalize the gradient tensor s.t. it sums to 1. Default: True.

Returns

gradient tensor with the same shape as pred_y.

Return type

torch.Tensor

torchray.attribution.common.gradient_to_saliency(x)[source]

Convert a gradient to a saliency map.

The tensor x must have a valid gradient x.grad. The function then computes the saliency map \(s\) given by:

\[s_{n,1,u} = \max_{0 \leq c < C} |dx_{ncu}|\]

where \(n\) is the instance index, \(c\) the channel index and \(u\) the spatial multi-index (usually of dimension 2 for images).
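
Equivalently, in code (a sketch of the same formula):

import torch

def max_abs_channel_saliency(x):
    """Sketch: maximum absolute gradient over the channel dimension."""
    return x.grad.abs().max(dim=1, keepdim=True)[0]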

Parameters

x (Tensor) – activation with gradient.

Returns

saliency

Return type

Tensor

class torchray.attribution.common.Probe(module, target='input')[source]

Bases: object

Probe for a layer.

A probe attaches to a given torch.nn.Module instance. While attached, the object records any data produced by the module along with the corresponding gradients. Use remove() to remove the probe.

Examples

import torch
from torchray.attribution.common import Probe

module = torch.nn.ReLU()
probe = Probe(module)
x = torch.randn(1, 10, requires_grad=True)
y = module(x)
z = y.sum()
z.backward()
print(probe.data[0].shape)
print(probe.data[0].grad.shape)
remove()[source]

Remove the probe.

class torchray.attribution.common.Patch(target, new_callable)[source]

Bases: object

Patch a callable in a module.

remove()[source]

Remove the patch.

static resolve(target)[source]

Resolve a target into a module and an attribute.

The function resolves a string such as 'this.that.thing' into a module instance this.that (importing the module) and an attribute thing.

Parameters

target (str) – target string.

Returns

module, attribute.

Return type

tuple

class torchray.attribution.common.ReLUContext(relu_func)[source]

Bases: object

A context manager that replaces torch.relu() with relu_func.

Parameters

relu_func (torch.autograd.function.FunctionMeta) – class definition of a torch.autograd.Function.

torchray.attribution.common.resize_saliency(tensor, saliency, size, mode)[source]

Resize a saliency map.

Parameters
  • tensor (torch.Tensor) – reference tensor.

  • saliency (torch.Tensor) – saliency map.

  • size (bool or tuple of int) – if a tuple (i.e., (width, height)), resize saliency to size. If True, resize saliency to the shape of tensor; otherwise, return saliency unchanged.

  • mode (str) – mode for torch.nn.functional.interpolate().

Returns

Resized saliency map.

Return type

torch.Tensor

torchray.attribution.common.saliency(model, input, target, saliency_layer='', resize=False, resize_mode='bilinear', smooth=0, context_builder=<class 'torchray.attribution.common.NullContext'>, gradient_to_saliency=<function gradient_to_saliency>, get_backward_gradient=<function get_backward_gradient>, debug=False)[source]

Apply a backprop-based attribution method to an image.

The saliency method is specified by a suitable context factory context_builder. This context is used to modify the backpropagation algorithm to match a given visualization method. This function:

  1. Attaches a probe to the output tensor of saliency_layer, which must be a layer in model. If no such layer is specified, it selects the input tensor to model.

  2. Uses the function get_backward_gradient to obtain a gradient for the output tensor of the model. This function is passed as input the output tensor as well as the parameter target. By default, the get_backward_gradient() function is used. The latter generates as gradient a one-hot vector selecting target, usually the index of the class predicted by model.

  3. Evaluates model on input and then computes the pseudo-gradient of the model with respect to the selected tensor. This calculation is controlled by context_builder.

  4. Extracts the pseudo-gradient at the selected tensor as a raw saliency map.

  5. Calls gradient_to_saliency to obtain an actual saliency map. This defaults to gradient_to_saliency(), which takes the maximum absolute value along the channel dimension of the pseudo-gradient tensor.

  6. Optionally resizes the saliency map thus obtained. By default, this uses bilinear interpolation and resizes the saliency to the same spatial dimension of input.

  7. Optionally applies a Gaussian filter to the resized saliency map. The standard deviation sigma of this filter is measured as a fraction of the maximum spatial dimension of the resized saliency map.

  8. Removes the probe.

  9. Returns the saliency map, or optionally a tuple with the saliency map and an OrderedDict of Probe objects for all modules in the model, which can be used for debugging.

Parameters
  • model (torch.nn.Module) – a model.

  • input (torch.Tensor) – input tensor.

  • target (int or torch.Tensor) – target label(s).

  • saliency_layer (str or torch.nn.Module, optional) – name of the saliency layer (str) or the layer itself (torch.nn.Module) in the model at which to visualize. Default: '' (visualize at input).

  • resize (bool or tuple, optional) – if True, upsample saliency map to the same size as input. It is also possible to specify a pair (width, height) for a different size. Default: False.

  • resize_mode (str, optional) – upsampling method to use. Default: 'bilinear'.

  • smooth (float, optional) – amount of Gaussian smoothing to apply to the saliency map. Default: 0.

  • context_builder (type, optional) – type of context to use. Default: NullContext.

  • gradient_to_saliency (function, optional) – function that converts the pseudo-gradient signal to a saliency map. Default: gradient_to_saliency().

  • get_backward_gradient (function, optional) – function that generates gradient tensor to backpropagate. Default: get_backward_gradient().

  • debug (bool, optional) – if True, also return a collections.OrderedDict of Probe objects for all modules in the model. Default: False.

Returns

If debug is False, returns a torch.Tensor saliency map at saliency_layer. Otherwise, returns a tuple of a torch.Tensor saliency map at saliency_layer and a collections.OrderedDict of Probe objects for all modules in the model.

Return type

torch.Tensor or tuple