Source code for torchray.attribution.common

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

r"""
This module defines common code for the backpropagation methods.
"""

import torch
import torch.nn.functional as F
import weakref
from collections import OrderedDict
from packaging import version

from torchray.utils import imsmooth

__all__ = [
    'attach_debug_probes',
    'get_backward_gradient',
    'get_module',
    'get_pointing_gradient',
    'gradient_to_saliency',
    'Probe',
    'Patch',
    'NullContext',
    'ReLUContext',
    'resize_saliency',
    'saliency'
]

# Certain algorithms fail to work properly in earlier versions.
assert version.parse(torch.__version__) >= version.parse('1.1'), \
    'PyTorch 1.1 or above required.'


class Patch(object):
    """Patch a callable in a module."""

    @staticmethod
    def resolve(target):
        """Resolve a target into a module and an attribute.

        The function resolves a string such as ``'this.that.thing'`` into a
        module instance `this.that` (importing the module) and an attribute
        `thing`.

        Args:
            target (str): target string.

        Returns:
            tuple: module, attribute.
        """
        target, attribute = target.rsplit('.', 1)
        components = target.split('.')
        import_path = components.pop(0)
        target = __import__(import_path)
        for comp in components:
            import_path += '.{}'.format(comp)
            __import__(import_path)
            target = getattr(target, comp)
        return target, attribute

    def __init__(self, target, new_callable):
        """Patch a callable in a module.

        Args:
            target (str): path to the callable to patch.
            new_callable (callable): new callable to install in its place.
        """
        target, attribute = Patch.resolve(target)
        self.target = target
        self.attribute = attribute
        self.orig_callable = getattr(target, attribute)
        setattr(target, attribute, new_callable)

    def __del__(self):
        self.remove()

    def remove(self):
        """Remove the patch."""
        if self.target is not None:
            setattr(self.target, self.attribute, self.orig_callable)
            self.target = None


class ReLUContext(object):
    """
    A context manager that replaces :func:`torch.relu` with
    :attr:`relu_func`.

    Args:
        relu_func (:class:`torch.autograd.function.FunctionMeta`): class
            definition of a :class:`torch.autograd.Function`.
    """

    def __init__(self, relu_func):
        assert isinstance(relu_func, torch.autograd.function.FunctionMeta)
        self.relu_func = relu_func
        self.patches = []

    def __enter__(self):
        relu = self.relu_func().apply
        self.patches = [
            Patch('torch.relu', relu),
            Patch('torch.relu_', relu),
        ]
        return self

    def __exit__(self, type, value, traceback):
        for p in self.patches:
            p.remove()
        return False  # re-raise any exception

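# Illustrative sketch, not part of the original module: ReLUContext expects
# the *class* of a torch.autograd.Function that re-implements ReLU. The
# hypothetical guided-backprop-style ReLU below zeroes negative gradients on
# the backward pass; TorchRay's attribution modules provide their own real
# counterparts.
class _ExampleGuidedReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Pass gradient only where both the input and the incoming gradient
        # are positive.
        return grad_output.clamp(min=0) * (x > 0).float()

# A hypothetical usage sketch (``model``, ``x`` and ``gradient`` are made up):
#
#     with ReLUContext(_ExampleGuidedReLU):
#         y = model(x)
#         y.backward(gradient)
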
def _wrap_in_list(x):
    if isinstance(x, list):
        return x
    elif isinstance(x, tuple):
        return list(x)
    else:
        return [x]


class _InjectContrast(object):
    def __init__(self, contrast, non_negative):
        self.contrast = contrast
        self.non_negative = non_negative

    def __call__(self, grad):
        assert grad.shape == self.contrast.shape
        delta = grad - self.contrast
        if self.non_negative:
            delta = delta.clamp(min=0)
        return delta


class _Catch(object):
    def __init__(self, probe):
        self.probe = weakref.ref(probe)

    def _process_data(self, data):
        if not self.probe():
            return
        p = self.probe()
        assert isinstance(data, list)
        p.data = data
        for i, x in enumerate(p.data):
            x.requires_grad_(True)
            x.retain_grad()
            if len(p.contrast) > i and p.contrast[i] is not None:
                injector = _InjectContrast(
                    p.contrast[i], p.non_negative_contrast)
                x.register_hook(injector)


class _CatchInputs(_Catch):
    def __call__(self, module, input):
        self._process_data(_wrap_in_list(input))


class _CatchOutputs(_Catch):
    def __call__(self, module, input, output):
        self._process_data(_wrap_in_list(output))


class Probe(object):
    """Probe for a layer.

    A probe attaches to a given :class:`torch.nn.Module` instance. While
    attached, the object records any data produced by the module along with
    the corresponding gradients. Use :func:`remove` to remove the probe.

    Examples:

        .. code:: python

            module = torch.nn.ReLU()
            probe = Probe(module)
            x = torch.randn(1, 10)
            y = module(x)
            z = y.sum()
            z.backward()
            print(probe.data[0].shape)
            print(probe.data[0].grad.shape)
    """

    def __init__(self, module, target='input'):
        """Create a probe attached to the specified module.

        The probe intercepts calls to the module on the way forward,
        capturing by default all the input activation tensors together with
        their gradients. The activation tensors are stored as a sequence
        :attr:`data`.

        Args:
            module (torch.nn.Module): module to attach.
            target (str): choose from ``'input'`` or ``'output'``. Use
                ``'output'`` to intercept the outputs of a module instead of
                the inputs into the module. Default: ``'input'``.

        .. warning::

            The PyTorch module interface (at least until 1.1.0) is partially
            broken. In particular, the hook functionality used by the probe
            works properly only for atomic modules, not for containers such
            as sequences or for complex modules that run several functions
            internally.
        """
        self.module = module
        self.data = []
        self.target = target
        self.hook = None
        self.contrast = []
        self.non_negative_contrast = False
        if hasattr(self.module, "inplace"):
            self.inplace = self.module.inplace
            self.module.inplace = False
        if self.target == 'input':
            self.hook = module.register_forward_pre_hook(_CatchInputs(self))
        elif self.target == 'output':
            self.hook = module.register_forward_hook(_CatchOutputs(self))
        else:
            assert False

    def __del__(self):
        self.remove()

    def remove(self):
        """Remove the probe."""
        if self.module is not None:
            if hasattr(self.module, "inplace"):
                self.module.inplace = self.inplace
            self.hook.remove()
            self.module = None


class NullContext(object):
    def __init__(self):
        r"""Null context.

        This context does nothing.
        """

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        return False


def get_pointing_gradient(pred_y, y, normalize=True):
    """Returns a gradient tensor for the pointing game.

    Args:
        pred_y (:class:`torch.Tensor`): 4D or 2D tensor output by the model.
        y (int): target label.
        normalize (bool): if True, normalize the gradient tensor so that it
            sums to 1. Default: ``True``.

    Returns:
        :class:`torch.Tensor`: gradient tensor with the same shape as
        :attr:`pred_y`.
    """
    assert isinstance(pred_y, torch.Tensor)
    assert len(pred_y.shape) == 4 or len(pred_y.shape) == 2
    assert pred_y.shape[0] == 1
    assert isinstance(y, int)
    backward_gradient = torch.zeros_like(pred_y)
    backward_gradient[0, y] = torch.exp(pred_y[0, y])
    if normalize:
        backward_gradient[0, y] /= backward_gradient[0, y].sum()
    return backward_gradient

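# Illustrative sketch, not part of the original module: for a spatial 4D
# prediction, the pointing-game gradient is non-zero only on the target
# channel and, when normalized, sums to one there. The function name below is
# made up for this example.
def _example_pointing_gradient():
    scores = torch.randn(1, 5, 4, 4)
    grad = get_pointing_gradient(scores, 2)
    assert grad.shape == scores.shape
    assert abs(grad[0, 2].sum().item() - 1.0) < 1e-5
    assert grad[0, 0].eq(0).all()
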
def get_backward_gradient(pred_y, y):
    r"""
    Returns a gradient tensor that is either equal to :attr:`y` (if :attr:`y`
    is a tensor with the same shape as :attr:`pred_y`) or a one-hot encoding
    in the channel dimension.

    :attr:`y` can be either an ``int``, an array-like list of integers, or a
    tensor. If :attr:`y` is a tensor with the same shape as :attr:`pred_y`,
    the function returns :attr:`y` unchanged.

    Otherwise, :attr:`y` is interpreted as a list of class indices. These are
    first unfolded/expanded to one index per batch element in :attr:`pred_y`
    (i.e. along the first dimension). Then, this list is further expanded to
    all spatial dimensions of :attr:`pred_y` (i.e. all but the first two
    dimensions of :attr:`pred_y`). Finally, the function returns a "gradient"
    tensor that is a one-hot indicator tensor for these classes.

    Args:
        pred_y (:class:`torch.Tensor`): model output tensor.
        y (int, :class:`torch.Tensor`, list, or :class:`np.ndarray`): target
            label(s) that can be cast to :class:`torch.long`.

    Returns:
        :class:`torch.Tensor`: gradient tensor with the same shape as
        :attr:`pred_y`.
    """
    assert isinstance(pred_y, torch.Tensor)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y, dtype=torch.long, device=pred_y.device)
    assert isinstance(y, torch.Tensor)

    if y.shape == pred_y.shape:
        return y
    assert y.dtype == torch.long

    nspatial = len(pred_y.shape) - 2
    grad = torch.zeros_like(pred_y)
    y = y.reshape(-1, 1, *((1,) * nspatial)).expand_as(grad)
    grad.scatter_(1, y, 1.)
    return grad

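# Illustrative sketch, not part of the original module: integer labels are
# expanded to a one-hot "gradient" over the channel dimension and broadcast
# over the spatial dimensions. The function name below is made up for this
# example.
def _example_backward_gradient():
    pred_y = torch.randn(2, 5, 7, 7)
    grad = get_backward_gradient(pred_y, [3, 1])
    assert grad.shape == pred_y.shape
    # One-hot along channels, constant over the 7x7 spatial grid.
    assert grad[0, 3].eq(1).all()
    assert grad[0].sum().item() == 7 * 7
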
def get_module(model, module):
    r"""Returns a specific layer in a model based on its name.

    :attr:`module` is either the name of a module (as given by the
    :func:`named_modules` function for :class:`torch.nn.Module` objects) or
    a :class:`torch.nn.Module` object. If :attr:`module` is a
    :class:`torch.nn.Module` object, then :attr:`module` is returned
    unchanged. If :attr:`module` is a str, the function searches for a module
    with the name :attr:`module` and returns a :class:`torch.nn.Module` if
    found; otherwise, ``None`` is returned.

    Args:
        model (:class:`torch.nn.Module`): model in which to search for layer.
        module (str or :class:`torch.nn.Module`): name of layer (str) or the
            layer itself (:class:`torch.nn.Module`).

    Returns:
        :class:`torch.nn.Module`: specific PyTorch layer (``None`` if the
        layer isn't found).
    """
    if isinstance(module, torch.nn.Module):
        return module

    assert isinstance(module, str)
    if module == '':
        return model

    for name, curr_module in model.named_modules():
        if name == module:
            return curr_module

    return None

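# Illustrative sketch, not part of the original module: get_module accepts
# either a dotted module name (as produced by named_modules) or a module
# instance. The toy model and function name below are made up for this
# example.
def _example_get_module():
    model = torch.nn.Sequential(OrderedDict([
        ('backbone', torch.nn.Sequential(OrderedDict([
            ('conv', torch.nn.Conv2d(3, 4, 3)),
        ]))),
    ]))
    assert get_module(model, '') is model
    assert get_module(model, 'backbone.conv') is model.backbone.conv
    assert get_module(model, 'missing') is None
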
def gradient_to_saliency(x):
    r"""Convert a gradient to a saliency map.

    The tensor :attr:`x` must have a valid gradient ``x.grad``. The function
    then computes the saliency map :math:`s` given by:

    .. math::

        s_{n,1,u} = \max_{0 \leq c < C} |dx_{ncu}|

    where :math:`n` is the instance index, :math:`c` the channel index and
    :math:`u` the spatial multi-index (usually of dimension 2 for images).

    Args:
        x (:class:`torch.Tensor`): activation tensor with a gradient.

    Returns:
        :class:`torch.Tensor`: saliency map.
    """
    return x.grad.abs().max(dim=1, keepdim=True)[0]

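# Illustrative sketch, not part of the original module: the saliency map is
# the channel-wise maximum of the absolute gradient, with the channel
# dimension kept. The function name below is made up for this example.
def _example_gradient_to_saliency():
    x = torch.randn(1, 3, 8, 8, requires_grad=True)
    (x ** 2).sum().backward()
    assert gradient_to_saliency(x).shape == (1, 1, 8, 8)
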
def resize_saliency(tensor, saliency, size, mode):
    """Resize a saliency map.

    Args:
        tensor (:class:`torch.Tensor`): reference tensor.
        saliency (:class:`torch.Tensor`): saliency map.
        size (bool or tuple of int): if a tuple (i.e., (width, height)),
            resize :attr:`saliency` to :attr:`size`. If True, resize
            :attr:`saliency` to the shape of :attr:`tensor`; otherwise,
            return :attr:`saliency` unchanged.
        mode (str): mode for :func:`torch.nn.functional.interpolate`.

    Returns:
        :class:`torch.Tensor`: resized saliency map.
    """
    if size is not False:
        if size is True:
            size = tensor.shape[2:]
        elif isinstance(size, tuple) or isinstance(size, list):
            # width, height -> height, width
            size = size[::-1]
        else:
            assert False, "resize must be True, False or a tuple."

        saliency = F.interpolate(
            saliency, size, mode=mode, align_corners=False)
    return saliency

def attach_debug_probes(model, debug=False):
    r"""
    Returns a :class:`collections.OrderedDict` of :class:`Probe` objects for
    all modules in the model if :attr:`debug` is ``True``; otherwise, returns
    ``None``.

    Args:
        model (:class:`torch.nn.Module`): a model.
        debug (bool, optional): if True, return an OrderedDict of Probe
            objects for all modules in the model; otherwise, return ``None``.
            Default: ``False``.

    Returns:
        :class:`collections.OrderedDict`: dict of :class:`Probe` objects for
        all modules in the model, or ``None``.
    """
    if not debug:
        return None

    debug_probes = OrderedDict()
    for module_name, module in model.named_modules():
        debug_probe_target = "input" if module_name == "" else "output"
        debug_probes[module_name] = Probe(
            module, target=debug_probe_target)
    return debug_probes

def saliency(model,
             input,
             target,
             saliency_layer='',
             resize=False,
             resize_mode='bilinear',
             smooth=0,
             context_builder=NullContext,
             gradient_to_saliency=gradient_to_saliency,
             get_backward_gradient=get_backward_gradient,
             debug=False):
    """Apply a backprop-based attribution method to an image.

    The saliency method is specified by a suitable context factory
    :attr:`context_builder`. This context is used to modify the
    backpropagation algorithm to match a given visualization method. This
    function:

    1. Attaches a probe to the output tensor of :attr:`saliency_layer`,
       which must be a layer in :attr:`model`. If no such layer is
       specified, it selects the input tensor to :attr:`model`.

    2. Uses the function :attr:`get_backward_gradient` to obtain a gradient
       for the output tensor of the model. This function is passed as input
       the output tensor as well as the parameter :attr:`target`. By
       default, the :func:`get_backward_gradient` function is used. The
       latter generates as gradient a one-hot vector selecting
       :attr:`target`, usually the index of the class predicted by
       :attr:`model`.

    3. Evaluates :attr:`model` on :attr:`input` and then computes the
       pseudo-gradient of the model with respect to the selected tensor.
       This calculation is controlled by :attr:`context_builder`.

    4. Extracts the pseudo-gradient at the selected tensor as a raw saliency
       map.

    5. Calls :attr:`gradient_to_saliency` to obtain an actual saliency map.
       This defaults to :func:`gradient_to_saliency` that takes the maximum
       absolute value along the channel dimension of the pseudo-gradient
       tensor.

    6. Optionally resizes the saliency map thus obtained. By default, this
       uses bilinear interpolation and resizes the saliency to the same
       spatial dimensions as :attr:`input`.

    7. Optionally applies a Gaussian filter to the resized saliency map. The
       standard deviation of this filter, given by :attr:`smooth`, is
       measured as a fraction of the maximum spatial dimension of the
       resized saliency map.

    8. Removes the probe.

    9. Returns the saliency map or, optionally, a tuple with the saliency
       map and an OrderedDict of Probe objects for all modules in the model,
       which can be used for debugging.

    Args:
        model (:class:`torch.nn.Module`): a model.
        input (:class:`torch.Tensor`): input tensor.
        target (int or :class:`torch.Tensor`): target label(s).
        saliency_layer (str or :class:`torch.nn.Module`, optional): name of
            the saliency layer (str) or the layer itself
            (:class:`torch.nn.Module`) in the model at which to visualize.
            Default: ``''`` (visualize at input).
        resize (bool or tuple, optional): if True, upsample saliency map to
            the same size as :attr:`input`. It is also possible to specify a
            pair (width, height) for a different size. Default: ``False``.
        resize_mode (str, optional): upsampling method to use. Default:
            ``'bilinear'``.
        smooth (float, optional): amount of Gaussian smoothing to apply to
            the saliency map. Default: ``0``.
        context_builder (type, optional): type of context to use. Default:
            :class:`NullContext`.
        gradient_to_saliency (function, optional): function that converts
            the pseudo-gradient signal to a saliency map. Default:
            :func:`gradient_to_saliency`.
        get_backward_gradient (function, optional): function that generates
            the gradient tensor to backpropagate. Default:
            :func:`get_backward_gradient`.
        debug (bool, optional): if True, also return an
            :class:`collections.OrderedDict` of :class:`Probe` objects for
            all modules in the model. Default: ``False``.

    Returns:
        :class:`torch.Tensor` or tuple: If :attr:`debug` is False, returns a
        :class:`torch.Tensor` saliency map at :attr:`saliency_layer`.
        Otherwise, returns a tuple of a :class:`torch.Tensor` saliency map
        at :attr:`saliency_layer` and an :class:`collections.OrderedDict` of
        :class:`Probe` objects for all modules in the model.
    """
    # Clear any existing gradient.
    if input.grad is not None:
        input.grad.data.zero_()

    # Disable gradients for model parameters.
    orig_requires_grad = {}
    for name, param in model.named_parameters():
        orig_requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    # Set model to eval mode.
    if model.training:
        orig_is_training = True
        model.eval()
    else:
        orig_is_training = False

    # Attach debug probes to every module.
    debug_probes = attach_debug_probes(model, debug=debug)

    # Attach a probe to the saliency layer.
    probe_target = 'input' if saliency_layer == '' else 'output'
    saliency_layer = get_module(model, saliency_layer)
    assert saliency_layer is not None, 'We could not find the saliency layer'
    probe = Probe(saliency_layer, target=probe_target)

    # Do a forward and backward pass.
    with context_builder():
        output = model(input)
        backward_gradient = get_backward_gradient(output, target)
        output.backward(backward_gradient)

    # Get saliency map from gradient.
    saliency_map = gradient_to_saliency(probe.data[0])

    # Resize saliency map.
    saliency_map = resize_saliency(input,
                                   saliency_map,
                                   resize,
                                   mode=resize_mode)

    # Smooth saliency map.
    if smooth > 0:
        saliency_map = imsmooth(
            saliency_map,
            sigma=smooth * max(saliency_map.shape[2:]),
            padding_mode='replicate'
        )

    # Remove probe.
    probe.remove()

    # Restore gradient saving for model parameters.
    for name, param in model.named_parameters():
        param.requires_grad_(orig_requires_grad[name])

    # Restore model's original mode.
    if orig_is_training:
        model.train()

    if debug:
        return saliency_map, debug_probes
    else:
        return saliency_map

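# Illustrative sketch, not part of the original module: end-to-end use of
# saliency() on a toy model. All names below are made up for this example.
def _example_saliency():
    model = torch.nn.Sequential(OrderedDict([
        ('conv1', torch.nn.Conv2d(3, 8, 3, padding=1)),
        ('relu1', torch.nn.ReLU()),
        ('conv2', torch.nn.Conv2d(8, 10, 1)),
        ('pool', torch.nn.AdaptiveAvgPool2d(1)),
    ]))
    x = torch.randn(1, 3, 32, 32)
    # Probe the output of 'conv1' and upsample the map back to the input size.
    saliency_map = saliency(model, x, target=3,
                            saliency_layer='conv1', resize=True)
    assert saliency_map.shape == (1, 1, 32, 32)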