Benchmarking
This module contains code for benchmarking attribution methods, including reproducing several published results. In addition to implementations of benchmarking protocols (such as the pointing game), the module provides reference datasets and reference models used in prior research, properly converted to PyTorch. Overall, these implementations closely reproduce prior results, notably the ones in the [EBP] paper.
A standard benchmarking suite is included in this library as examples.standard_suite. For slow methods, a computer cluster may be required for evaluation (we do not include explicit support for clusters, but it is easy to add on top of this example code).
It is also recommended to turn on logging (see torchray.benchmark.logging), which allows the driver to use MongoDB to store partial benchmarking results as it goes. Computations can then be cached and reused to resume the calculations after a crash or other issue. In order to start the logging server, use
$ python -m torchray.benchmark.server
The server parameters (address, port, etc.) can be configured by writing a .torchrayrc file in your current or home directory. The package contains an example configuration file. The server creates a regular MongoDB database (by default in ./data/db) which can be explored manually by means of the MongoDB shell.
By default, the driver writes data in the ./data/ subfolder. You can change that via the configuration file or, possibly more easily, by adding a symbolic link pointing to where you want to store the data.

The data include the datasets (PASCAL VOC, COCO, ImageNet; see torchray.benchmark.datasets). These must be downloaded manually and stored in ./data/datasets/{voc,coco,imagenet} unless this is changed via the configuration file. Note that these datasets can be very large (many GB).

The data also include reference models (see torchray.benchmark.models).
This module also provides a few functions for getting and plotting example data.
torchray.benchmark.get_example_data(arch='vgg16', shape=224)

Get example data to demonstrate visualization techniques.

Parameters:

- arch (str, optional): name of the torchvision.models architecture. Default: 'vgg16'.
- shape (int or tuple of int, optional): shape to resize the input image to. Default: 224.

Returns: a tuple containing:

- a convolutional neural network model in evaluation mode;
- a sample input tensor image;
- the ImageNet category id of an object in the image;
- the ImageNet category id of another object in the image.

Return type: (torch.nn.Module, torch.Tensor, int, int)
torchray.benchmark.plot_example(input, saliency, method, category_id, show_plot=False, save_path=None)

Plot an example.

Parameters:

- input (torch.Tensor): 4D tensor containing input images.
- saliency (torch.Tensor): 4D tensor containing saliency maps.
- method (str): name of the saliency method.
- category_id (int): ID of the ImageNet category.
- show_plot (bool, optional): if True, show the plot. Default: False.
- save_path (str, optional): path to save the figure to. Default: None.
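For instance, these two functions can be combined with one of the library's attribution methods for a quick visualization. The sketch below mirrors the library's README example; the layer name 'features.29' (the last ReLU in VGG16's features block) is specific to that architecture:

```python
from torchray.attribution.grad_cam import grad_cam
from torchray.benchmark import get_example_data, plot_example

# A pre-trained model in evaluation mode, an input image, and two ImageNet
# category ids of objects appearing in the image.
model, x, category_id, _ = get_example_data(arch='vgg16')

# Grad-CAM saliency for the first category, computed at a late conv layer.
saliency = grad_cam(model, x, category_id, saliency_layer='features.29')

# Show the input image next to its saliency map.
plot_example(x, saliency, 'grad-cam backprop', category_id, show_plot=True)
```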
Pointing Game
The Pointing Game [EBP] assesses the quality of an attribution method by testing how well it can extract from a predictor a response correlated with the presence of known object categories in the image.
Given an input image \(x\) containing an object of category \(c\), the attribution method is applied to the predictor in order to find the part of the image responsible for predicting \(c\). The attribution method usually returns a saliency heatmap. The latter must then be converted into a single point \((u,v)\) that is “most likely” to be contained by an object of that class. The specific way the point is obtained is method-dependent.
The attribution method then scores a hit if the point is within a tolerance \(\tau\) (set to 15 pixels by default) of the image region \(\Omega\) containing that object:
\[\operatorname{hit}(u,v|\Omega) = [ \exists (u',v') \in \Omega : \|(u,v) - (u',v')\| \leq \tau].\]
The point coordinates \((u,v)\) are also indices \(x_{ncvu}\) in the input image tensor \(x\).
RISE [RISE] and Extremal Perturbation [EP] results are averaged over 3 runs.
Pointing game hit rate (%); for each dataset and architecture, the two columns report the full set and its “difficult” subset:

| method | voc_2007 vgg16 (all) | voc_2007 vgg16 (difficult) | voc_2007 resnet50 (all) | voc_2007 resnet50 (difficult) | coco vgg16 (all) | coco vgg16 (difficult) | coco resnet50 (all) | coco resnet50 (difficult) |
|---|---|---|---|---|---|---|---|---|
| center | 69.6 | 42.4 | 69.6 | 42.4 | 27.8 | 19.5 | 27.8 | 19.5 |
| gradient | 76.3 | 56.9 | 72.3 | 56.8 | 37.7 | 31.4 | 35.0 | 29.4 |
| deconvnet | 67.5 | 44.2 | 68.6 | 44.7 | 30.7 | 23.0 | 30.0 | 21.9 |
| guided_backprop | 75.9 | 53.0 | 77.2 | 59.4 | 39.1 | 31.4 | 42.1 | 35.3 |
| excitation_backprop | 77.1 | 56.6 | 84.5 | 70.8 | 39.8 | 32.8 | 49.6 | 43.9 |
| contrastive_excitation_backprop | 79.9 | 66.5 | 90.7 | 82.1 | 49.7 | 44.3 | 58.5 | 53.6 |
| rise | 86.9 | 75.1 | 86.4 | 78.8 | 50.8 | 45.3 | 54.7 | 50.0 |
| grad_cam | 86.6 | 74.0 | 90.4 | 82.3 | 54.2 | 49.0 | 57.3 | 52.3 |
| extremal_perturbation | 88.0 | 76.1 | 88.9 | 78.7 | 51.5 | 45.9 | 56.5 | 51.5 |
The pointing_game module implements the pointing game benchmark. The basic benchmark is implemented by the PointingGame class; however, for benchmarking purposes it is recommended to use the wrapper class PointingGameBenchmark instead. This class supports PASCAL VOC 2007 test and COCO 2014 val with the modifications used in [EBP], including the ability to run on their “difficult” subsets as defined in the original paper.
The class can be used as follows (a code sketch follows the list):

1. Obtain a dataset (usually COCO or PASCAL VOC) and choose a subset.
2. Initialize an instance of PointingGameBenchmark.
3. For each image in the dataset and for each class in the image:
   - Run the attribution method, usually resulting in a saliency map for class \(c\).
   - Convert the result to a point, usually by finding the maximizer of the saliency map.
   - Use the PointingGameBenchmark.evaluate() function to run the test and accumulate the statistics.
4. Extract PointingGame.hits and PointingGame.misses, or print the instance to display the results.
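A minimal sketch of this loop, under a few assumptions: the subset name 'test', the default data location under ./data/datasets, and Grad-CAM standing in for the attribution method under test:

```python
import torch

from torchray.attribution.grad_cam import grad_cam
from torchray.benchmark.datasets import get_dataset
from torchray.benchmark.models import get_model, get_transform
from torchray.benchmark.pointing_game import PointingGameBenchmark

model = get_model(arch='vgg16', dataset='voc')
dataset = get_dataset('voc', 'test', transform=get_transform(dataset='voc'))
benchmark = PointingGameBenchmark(dataset)

for image, label in dataset:
    x = image.unsqueeze(0)  # make a 4D batch of one image
    for class_id in dataset.as_class_ids(label):
        # Saliency map for this class (Grad-CAM as an example method).
        saliency = grad_cam(model, x, class_id, saliency_layer='features.29')
        # Upsample to image resolution so the point is in image coordinates.
        saliency = torch.nn.functional.interpolate(
            saliency, size=x.shape[2:], mode='bilinear', align_corners=False)
        # Convert the map to a point by taking its maximizer.
        v, u = divmod(saliency[0, 0].argmax().item(), saliency.shape[3])
        benchmark.evaluate(label, class_id, (u, v))

print(benchmark)
```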
class torchray.benchmark.pointing_game.PointingGame(num_classes, tolerance=15)

Bases: object

Pointing game.

Parameters:

- num_classes (int): number of classes in the dataset.
- tolerance (int, optional): tolerance (in pixels) of the margin around the ground-truth annotation. Default: 15.

hits (torch.Tensor)
num_classes-dimensional vector of hit counts.

misses (torch.Tensor)
num_classes-dimensional vector of miss counts.

property accuracy (torch.Tensor)
Mean accuracy, computed by averaging class_accuracies.

property class_accuracies (torch.Tensor)
num_classes-dimensional vector containing the per-class accuracy.

evaluate(mask, point)

Evaluate a point prediction. The function tests whether the prediction point is within a certain tolerance of the object ground-truth region mask, expressed as a boolean occupancy map. Use the reset() method to clear all counters.

Parameters:

- mask (torch.Tensor): ground-truth occupancy map \(\{0,1\}^{H\times W}\).
- point (tuple of ints): predicted point \((u,v)\).

Returns: +1 if the point hits the object; otherwise -1.
Return type: int
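A toy example of the hit test, with an assumed rectangular ground-truth region:

```python
import torch

from torchray.benchmark.pointing_game import PointingGame

game = PointingGame(num_classes=20, tolerance=15)

# Hypothetical ground-truth occupancy map with one rectangular object.
mask = torch.zeros(224, 224, dtype=torch.bool)
mask[40:120, 60:160] = True  # rows (v) 40-119, columns (u) 60-159

print(game.evaluate(mask, (70, 50)))   # inside the region: +1
print(game.evaluate(mask, (200, 10)))  # more than 15 pixels away: -1
```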
class torchray.benchmark.pointing_game.PointingGameBenchmark(dataset, tolerance=15, difficult=False)

Bases: torchray.benchmark.pointing_game.PointingGame

Pointing game benchmark on standard datasets.

The pointing game should be initialized with a dataset, set to either:

- (torchvision.VOCDetection) the VOC 2007 test subset, or
- (torchvision.CocoDetection) the COCO val2014 subset.

Parameters:

- dataset (torchvision.VisionDataset): the dataset.
- tolerance (int): the tolerance for the pointing game. Default: 15.
- difficult (bool): whether to use the difficult subset. Default: False.

evaluate(label, class_id, point)

Evaluate a label-class-point triplet.

Parameters:

- label (dict): a label in VOC or COCO detection format.
- class_id (int): a class id.
- point (iterable): a point specified as a pair of (u, v) coordinates.

Returns: +1 if the point hits the object, -1 if the point misses the object, and 0 if the point is skipped during evaluation.
Return type: int
Datasets
This module provides a number of benchmark datasets:

- ImageNet ILSVRC 12 and other image-folder datasets (ImageFolder).
- PASCAL VOC (VOCDetection).
- MS COCO (CocoDetection).

The classes in this module extend the corresponding classes in torchvision.datasets with functions for converting labels between various formats and similar utilities. Some of these functions are also provided stand-alone.
torchray.benchmark.datasets.IMAGENET_CLASSES
List of the 1000 ImageNet ILSVRC class names.

torchray.benchmark.datasets.VOC_CLASSES
List of the 20 PASCAL VOC class names.

torchray.benchmark.datasets.COCO_CLASSES
List of the 80 COCO class names.
torchray.benchmark.datasets.coco_as_class_ids(label)

Convert a COCO detection label to a list of class IDs.

Parameters:

- label (list of dict): an image label in the COCO detection format.

Returns: list of the IDs of the classes in the image.
Return type: list
torchray.benchmark.datasets.coco_as_image_size(dataset, label)

Convert a COCO detection label to the image size.

Parameters:

- dataset (CocoDetection): the dataset the label belongs to.
- label (list of dict): an image label in the COCO detection format.

Returns: width, height of the image.
Return type: tuple
torchray.benchmark.datasets.coco_as_mask(dataset, label, class_id)

Convert a COCO detection label to a mask.

Returns a boolean mask for the regions of class_id. If the label is the empty list, because there are no objects at all in the image, the function returns None.

Parameters:

- dataset (CocoDetection): the dataset the label belongs to.
- label (list of dict): an image label in the COCO detection format.
- class_id (int): ID of the requested class.

Returns: 2D boolean tensor.
Return type: torch.Tensor
torchray.benchmark.datasets.voc_as_class_ids(label)

Convert a VOC detection label to a list of class IDs.

Parameters:

- label (dict): an image label in the VOC detection format.

Returns: list of the IDs of the classes in the image.
Return type: list
torchray.benchmark.datasets.voc_as_image_size(label)

Convert a VOC detection label to the image size.

Parameters:

- label (dict): an image label in the VOC detection format.

Returns: width, height of the image.
Return type: tuple
torchray.benchmark.datasets.voc_as_mask(label, class_id)

Convert a VOC detection label to a mask.

Returns a boolean mask selecting the region contained in the bounding boxes of class_id.

Parameters:

- label (dict): an image label in the VOC detection format.
- class_id (int): ID of the requested class.

Returns: 2D boolean tensor.
Return type: torch.Tensor
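For instance, a per-class mask can be obtained from a VOC label as follows (the dataset path is an assumption; adjust it to your configuration):

```python
from torchray.benchmark.datasets import VOC_CLASSES, VOCDetection, voc_as_mask

# Assumes PASCAL VOC 2007 is stored under ./data/datasets/voc.
dataset = VOCDetection('./data/datasets/voc', year='2007', image_set='test')
image, label = dataset[0]

class_id = VOC_CLASSES.index('person')  # pick a class of interest
mask = voc_as_mask(label, class_id)     # 2D boolean tensor over the image
```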
class torchray.benchmark.datasets.ImageFolder(*args, limiter=None, full_classes=None, **kwargs)

Bases: torchvision.datasets.folder.ImageFolder

Image folder dataset. This class extends torchvision.datasets.ImageFolder. Its constructor supports the following additional arguments:

Parameters:

- limiter (int, optional): limit the dataset to limiter images, picking from each class in a round-robin fashion. Default: None.
- full_classes (list of str, optional): list of full class names. Default: None.

selection (list of int)
indices of the active images.

full_classes (list of str)
class names.
class torchray.benchmark.datasets.VOCDetection(*args, limiter=None, **kwargs)

Bases: torchvision.datasets.voc.VOCDetection

PASCAL VOC detection dataset. This class extends torchvision.datasets.VOCDetection. Its constructor supports the following additional arguments:

Parameters:

- limiter (int, optional): limit the dataset to the first limiter images. Default: None.

selection (list of int)
indices of the active images.

classes (list of str)
class names.
as_class_ids(label)
Convert a label to a list of class IDs. The same as voc_as_class_ids().

as_image_size(label)
Convert a label to the image size. The same as voc_as_image_size().

as_mask(label, class_id)
Convert a label to a mask. The same as voc_as_mask().
class torchray.benchmark.datasets.CocoDetection(root, annFile, *args, limiter=None, **kwargs)

Bases: torchvision.datasets.coco.CocoDetection

COCO detection dataset. The data can be downloaded at http://cocodataset.org/#download.

Parameters:

- limiter (int, optional): limit the dataset to the first limiter images. Default: None.

classes (list of str)
class names.

selection (list of int)
indices of the active images.
as_class_ids(label)
Convert a label to a list of class IDs. The same as coco_as_class_ids().

as_image_size(label)
Convert a label to the image size. The same as coco_as_image_size().

as_mask(label, class_id)
Convert a label to a mask. The same as coco_as_mask().

get_image_url(i)

Return the path of an image.

Parameters:

- i (int): image index.

Returns: path to the image.
Return type: str

property images (list of str)
paths to the images.
torchray.benchmark.datasets.get_dataset(name, subset, dataset_dir=None, annotation_dir=None, transform=None, limiter=None, download=False)

Return a torch.utils.data.Dataset.

Parameters:

- name (str): name of the dataset; choose from "imagenet", "voc", or "coco".
- subset (str): name of the dataset subset or split.
- dataset_dir (str, optional): path to the root directory containing the data. Default: None.
- annotation_dir (str, optional): path to the root directory containing the annotations. Required for COCO only. Default: None.
- transform (function, optional): input transformation function. Default: None.
- limiter (int, optional): limit the dataset to limiter images. Default: None.
- download (bool, optional): if True and name is "voc", download the dataset to dataset_dir. Default: False.

Returns: the requested dataset.
Return type: torch.utils.data.Dataset
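For example (the subset names and the COCO annotation path below are assumptions based on the splits mentioned in this document):

```python
from torchray.benchmark.datasets import get_dataset
from torchray.benchmark.models import get_transform

# PASCAL VOC test set with the pre-processing matching the reference models;
# download=True fetches the data on first use (VOC only).
voc = get_dataset('voc', 'test', transform=get_transform(dataset='voc'),
                  download=True)

# COCO val2014; the annotations must be downloaded separately.
coco = get_dataset('coco', 'val2014',
                   annotation_dir='./data/datasets/coco/annotations')

print(len(voc), len(coco))
```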
Reference models
This module allows obtaining standard models for benchmarking attribution methods. The models can be obtained via the function get_model(). The function can edit models slightly to make them compatible with benchmarks. Optional modifications include:

- converting a model to a fully-convolutional one (by replacing linear layers with equivalent convolutional layers);
- adding a Global Average Pooling (GAP) layer at the end, so that a fully-convolutional model can still work as an image classifier.

For the pointing game, we support the VGG16 and ResNet50 models fine-tuned on the PASCAL VOC 2007 and COCO 2014 classification tasks from the paper [EBP] that introduced this test. These models are converted from the original Caffe implementation and reproduce the results in [EBP].
torchray.benchmark.models.get_model(arch='vgg16', dataset='voc', convert_to_fully_convolutional=False)

Return a reference model for the specified architecture and dataset. The model is returned in evaluation mode.

Parameters:

- arch (str, optional): name of the architecture. If dataset contains "imagenet", all torchvision.models architectures are supported; otherwise, only "vgg16" and "resnet50" are currently supported. Default: 'vgg16'.
- dataset (str, optional): name of the dataset; should contain 'imagenet', 'voc', or 'coco'. Default: 'voc'.
- convert_to_fully_convolutional (bool, optional): if True, convert the model to a fully-convolutional one. Default: False.

Returns: the model.
Return type: torch.nn.Module
torchray.benchmark.models.get_transform(dataset='imagenet', size=224)

Return a composition of standard pre-processing transformations for feeding models. For non-ImageNet datasets, the transforms are for models converted from Caffe (i.e., Caffe pre-processing).

Parameters:

- dataset (str): name of the dataset; should contain 'imagenet', 'coco', or 'voc'. Default: 'imagenet'.
- size (sequence or int): desired output size (see torchvision.transforms.Resize for details).

Returns: the transform.
Return type: torchvision.Transform
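A typical pairing of the two functions (a sketch; the architecture and dataset choice is arbitrary):

```python
from torchray.benchmark.models import get_model, get_transform

# Reference ResNet50 fine-tuned on COCO, in evaluation mode and converted
# to a fully-convolutional network.
model = get_model(arch='resnet50', dataset='coco',
                  convert_to_fully_convolutional=True)

# Matching pre-processing (Caffe-style, since this is a non-ImageNet model).
transform = get_transform(dataset='coco', size=224)
```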
torchray.benchmark.models.replace_module(model, module_name, new_module)

Replace a torch.nn.Module with another one in a model.

Parameters:

- model (torch.nn.Module): model in which to find and replace the module named module_name with new_module.
- module_name (str): path of the module to replace in the model, with '.' denoting membership in another module. For example, 'features.11' in AlexNet (given by torchvision.models.alexnet()) refers to the 11th module in the 'features' module, that is, the torch.nn.ReLU module after the last conv layer in AlexNet.
- new_module (torch.nn.Module): replacement module.
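For instance, continuing the AlexNet example (whether the replacement happens in place or via a return value is not documented here, so this sketch assumes in-place modification):

```python
import torch
import torchvision

from torchray.benchmark.models import replace_module

model = torchvision.models.alexnet()

# Swap the ReLU after the last conv layer ('features.11') for a non-inplace
# one, as some backpropagation-based attribution methods require.
replace_module(model, 'features.11', torch.nn.ReLU(inplace=False))
```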
Logging with MongoDB
This module provides functions to log information (e.g., benchmark results) to a MongoDB database. See examples.standard_suite for an example of how to use MongoDB for logging benchmark results.
To start a MongoDB server, use
$ python -m torchray.benchmark.server
torchray.benchmark.logging.mongo_connect(database)

Connect to the MongoDB server and return a pymongo.database.Database object.

Parameters:

- database (str): name of the database.

Returns: the database.
Return type: pymongo.database.Database
torchray.benchmark.logging.mongo_save(database, collection_key, id_key, data)

Save results to a MongoDB database.

Parameters:

- database (pymongo.database.Database): MongoDB database to save results to.
- collection_key (str): name of the collection.
- id_key (str): id key under which to store data.
- data (bson.binary.Binary or dict): data to store in the database.
torchray.benchmark.logging.mongo_load(database, collection_key, id_key)

Load data from a MongoDB database.

Parameters:

- database (pymongo.database.Database): MongoDB database to load data from.
- collection_key (str): name of the collection.
- id_key (str): id key to look up the data.

Returns: the retrieved data, or None if no data with id_key is found.
torchray.benchmark.logging.data_to_mongo(data)

Prepare data to be stored in a MongoDB database.

Parameters:

- data (dict, torch.Tensor, or np.ndarray): data to prepare for storage in a MongoDB database (if dict, items are recursively prepared for storage). If the underlying data is not a torch.Tensor or np.ndarray, then data is returned as is.

Returns: correctly formatted data to store in a MongoDB database.
Return type: bson.binary.Binary or dict of bson.binary.Binary
torchray.benchmark.logging.data_from_mongo(mongo_data, map_location=None)

Decode data stored in a MongoDB database.

Parameters:

- mongo_data (bson.binary.Binary or dict): data to decode (if dict, items are recursively decoded). If the underlying data type is not a torch.Tensor or something stored using pickle, then mongo_data is returned as is.
- map_location (function, torch.device, str, or dict): where to remap storage locations (see torch.load() for details). Default: None.

Returns: the decoded data.
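A save-and-resume round trip might look like this (a sketch: the database, collection, and id names are arbitrary choices, and the logging server is assumed to be running already):

```python
import torch

from torchray.benchmark.logging import (
    data_from_mongo, data_to_mongo, mongo_connect, mongo_load, mongo_save)

db = mongo_connect('my_experiments')  # database name of our choosing

# Encode tensors for storage and save them under an id of our choosing.
results = {'method': 'grad_cam', 'saliency': torch.rand(1, 1, 7, 7)}
mongo_save(db, 'pointing_game', 'image_0001', data_to_mongo(results))

# Later (e.g., when resuming after a crash), check the cache first.
cached = mongo_load(db, 'pointing_game', 'image_0001')
if cached is not None:
    results = data_from_mongo(cached)
```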
torchray.benchmark.logging.last_lines(string, num_lines)

Extract the last few lines from a string.

The function extracts the last num_lines lines from string. If num_lines is a negative number, then it extracts the first lines instead. It also skips lines beginning with 'Figure('.

Parameters:

- string (str): the string.
- num_lines (int): number of lines to extract.

Returns: the substring.
Return type: str