# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
r"""
This module provides an implementation of the *RISE* method of [RISE]_ for
saliency visualization. This is given by the :func:`rise` function, which
can be used as follows:
.. literalinclude:: ../examples/rise.py
:language: python
:linenos:
References:
.. [RISE] V. Petsiuk, A. Das and K. Saenko
*RISE: Randomized Input Sampling for Explanation of Black-box
Models,*
BMVC 2018,
`<https://arxiv.org/pdf/1806.07421.pdf>`__.
"""
__all__ = ['rise', 'rise_class']
import numpy as np
import torch
import torch.nn.functional as F
from .common import resize_saliency
def _upsample_reflect(x, size, interpolate_mode="bilinear"):
r"""Upsample 4D :class:`torch.Tensor` with reflection padding.
Args:
x (:class:`torch.Tensor`): 4D tensor to interpolate.
size (int or list or tuple of ints): target size
interpolate_mode (str): mode to pass to
:function:`torch.nn.functional.interpolate` function call
(default: "bilinear").
Returns:
:class:`torch.Tensor`: upsampled tensor.
"""
# Check and get input size.
assert len(x.shape) == 4
orig_size = x.shape[2:]
# Check target size.
if not isinstance(size, tuple) and not isinstance(size, list):
assert isinstance(size, int)
size = (size, size)
assert len(size) == 2
# Ensure upsampling.
for i, o_s in enumerate(orig_size):
assert o_s <= size[i]
# Get size of input cell when interpolated.
cell_size = [int(np.ceil(s / orig_size[i])) for i, s in enumerate(size)]
# Get size of interpolated input with padding.
pad_size = [int(cell_size[i] * (orig_size[i] + 2))
for i in range(len(orig_size))]
# Pad input with reflection padding.
x_padded = F.pad(x, (1, 1, 1, 1), mode="reflect")
# Interpolated padded input.
x_up = F.interpolate(x_padded,
pad_size,
mode=interpolate_mode,
align_corners=False)
# Slice interpolated input to size.
x_new = x_up[:,
:,
cell_size[0]:cell_size[0] + size[0],
cell_size[1]:cell_size[1] + size[1]]
return x_new
def rise_class(*args, target, **kwargs):
    r"""Class-specific RISE.

    This function has all the arguments of :func:`rise` with the following
    additional keyword-only argument and returns a class-specific saliency
    map for the given :attr:`target` class(es).

    Args:
        target (int, :class:`torch.Tensor`, list, or :class:`np.ndarray`):
            target label(s) that can be cast to :class:`torch.long`. Either
            one label per input sample, or a single scalar label that is
            used for every sample in the batch.

    Returns:
        :class:`torch.Tensor`: saliency map with the same shape as the output
        of :func:`rise` except that dimension 1 has size 1 (the channel of
        the requested class for each sample).
    """
    saliency = rise(*args, **kwargs)
    assert len(saliency.shape) == 4
    num_samples = len(saliency)

    # Bring the labels to a long tensor on the saliency device, casting as
    # documented instead of rejecting tensors of another integer dtype.
    if isinstance(target, torch.Tensor):
        target = target.to(device=saliency.device, dtype=torch.long)
    else:
        target = torch.tensor(target, dtype=torch.long, device=saliency.device)

    # A scalar label (e.g. a plain int) applies to every sample; a 0-d
    # tensor has no len(), so it must be expanded before the length check.
    if target.dim() == 0:
        target = target.expand(num_samples)
    assert target.dim() == 1
    assert len(target) == num_samples

    # Pick, for every sample i, the saliency channel target[i].
    sample_index = torch.arange(num_samples, device=saliency.device)
    class_saliency = saliency[sample_index, target].unsqueeze(1)

    output_shape = list(saliency.shape)
    output_shape[1] = 1
    assert list(class_saliency.shape) == output_shape
    return class_saliency
def rise(model,
         input,
         target=None,
         seed=0,
         num_masks=8000,
         num_cells=7,
         filter_masks=None,
         batch_size=32,
         p=0.5,
         resize=False,
         resize_mode='bilinear'):
    r"""RISE.

    For every input sample, the model is evaluated on the sample multiplied
    by each random mask; the sigmoid of the model output is used to weight
    the masks, and the weighted masks are summed and divided by
    :attr:`num_masks`.

    The global random number generator is seeded with :attr:`seed` while the
    masks are generated; the previous generator state of the CPU (and of the
    CUDA device holding :attr:`input`, if any) is restored afterwards, also
    when an exception is raised.

    Args:
        model (:class:`torch.nn.Module`): a model.
        input (:class:`torch.Tensor`): input tensor with two spatial
            dimensions (4D).
        target (optional): not used by this function; see
            :func:`rise_class` for class-specific saliency.
            Default: ``None``.
        seed (int, optional): manual seed used to generate random numbers.
            Default: ``0``.
        num_masks (int, optional): number of RISE random masks to use.
            Default: ``8000``.
        num_cells (int, optional): number of cells for one spatial dimension
            in low-res RISE random mask. Default: ``7``.
        filter_masks (:class:`torch.Tensor`, optional): If given, use the
            provided pre-computed filter masks; its length must equal
            :attr:`num_masks`. Default: ``None``.
        batch_size (int, optional): number of masks evaluated per forward
            pass. Default: ``32``.
        p (float, optional): with prob p, a low-res cell is set to 1 (the
            input is kept); otherwise, it's 0 (the input is masked out).
            Default: ``0.5``.
        resize (bool or tuple of ints, optional): If True, resize saliency map
            to size of :attr:`input`. If False, don't resize. If (width,
            height) tuple, resize to (width, height). Default: ``False``.
        resize_mode (str, optional): If resize is not None, use this mode for
            the resize function. Default: ``'bilinear'``.

    Returns:
        :class:`torch.Tensor`: RISE saliency map; before the optional resize
        it has one channel per model output class and the spatial size of
        :attr:`input`.
    """
    with torch.no_grad():
        # Get device of input (i.e., GPU).
        dev = input.device

        # Number of spatial dimensions; checked before indexing the shape.
        input_shape = input.shape
        nsd = len(input_shape) - 2
        assert nsd == 2
        height, width = input_shape[2:]

        # Initialize saliency mask: one channel per output class.
        out = model(input)
        num_classes = out.shape[1]
        saliency_shape = list(input_shape)
        saliency_shape[1] = num_classes
        saliency = torch.zeros(saliency_shape, device=dev)

        # Spatial size of low-res grid cell.
        cell_size = tuple(int(np.ceil(s / num_cells))
                          for s in input_shape[2:])

        # Spatial size of upsampled mask with buffer (input size + cell size).
        up_size = tuple(s + c for s, c in zip(input_shape[2:], cell_size))

        if filter_masks is not None:
            assert len(filter_masks) == num_masks

        # Save current random number generator state. ``torch.manual_seed``
        # also reseeds the CUDA generators, so the state of the device the
        # masks are sampled on has to be saved as well.
        cpu_rng_state = torch.get_rng_state()
        cuda_rng_state = (torch.cuda.get_rng_state(dev)
                          if dev.type == 'cuda' else None)

        try:
            # Set seed.
            torch.manual_seed(seed)

            num_chunks = (num_masks + batch_size - 1) // batch_size
            for chunk in range(num_chunks):
                # The last chunk may hold fewer than batch_size masks.
                mask_bs = min(num_masks - batch_size * chunk, batch_size)

                if filter_masks is None:
                    # Generate low-res, random binary masks on the fly.
                    grid = (torch.rand(mask_bs, 1, *((num_cells,) * nsd),
                                       device=dev) < p).float()

                    # Upsample low-res masks to input shape + buffer.
                    masks_up = _upsample_reflect(grid, up_size)

                    # Random per-mask shift of less than one cell along
                    # each spatial dimension (height first, then width).
                    shift_h = torch.randint(0,
                                            cell_size[0],
                                            (mask_bs,),
                                            device='cpu').tolist()
                    shift_w = torch.randint(0,
                                            cell_size[1],
                                            (mask_bs,),
                                            device='cpu').tolist()

                    # Crop an input-sized window at the random shift.
                    masks = torch.stack([
                        masks_up[i, :, dh:dh + height, dw:dw + width]
                        for i, (dh, dw) in enumerate(zip(shift_h, shift_w))
                    ])
                else:
                    start = chunk * batch_size
                    masks = filter_masks[start:start + mask_bs]

                # reshape (not view) so non-contiguous masks also work.
                flat_masks = masks.reshape(mask_bs, height * width)

                # Accumulate saliency mask.
                for i, inp in enumerate(input):
                    out = torch.sigmoid(model(inp.unsqueeze(0) * masks))
                    if len(out.shape) == 4:
                        # TODO: Consider handling FC outputs more flexibly.
                        assert out.shape[2] == 1
                        assert out.shape[3] == 1
                        out = out[:, :, 0, 0]
                    # (classes x masks) @ (masks x pixels): every mask is
                    # weighted by the score it obtained for each class.
                    sal = torch.matmul(out.transpose(0, 1), flat_masks)
                    saliency[i] += sal.view(num_classes, height, width)
        finally:
            # Restore original random number generator state, even if the
            # model or the mask generation raised.
            torch.set_rng_state(cpu_rng_state)
            if cuda_rng_state is not None:
                torch.cuda.set_rng_state(cuda_rng_state, dev)

        # Normalize saliency mask.
        saliency /= num_masks

        # Resize saliency mask if needed.
        saliency = resize_saliency(input,
                                   saliency,
                                   resize,
                                   mode=resize_mode)
        return saliency