Source code for torchray.benchmark

r"""This script provides a few functions for getting and plotting example data.
"""
import os
import torchvision
from matplotlib import pyplot as plt

from .datasets import *  # noqa
from .models import *  # noqa


def get_example_data(arch='vgg16', shape=224):
    """Get example data to demonstrate visualization techniques.

    Loads a pre-trained torchvision model, downloads a sample image from
    Wikimedia, pre-processes it into a tensor, and moves both the model and
    the tensor to the device returned by ``torchray.utils.get_device``.

    Args:
        arch (str, optional): name of torchvision.models architecture.
            Default: ``'vgg16'``.
        shape (int or tuple of int, optional): shape to resize input image
            to. Default: ``224``.

    Returns:
        (:class:`torch.nn.Module`, :class:`torch.Tensor`, int, int): a tuple
        containing

            - a convolutional neural network model in evaluation mode.
            - a sample input tensor image.
            - the ImageNet category id of an object in the image.
            - the ImageNet category id of another object in the image.

    Raises:
        KeyError: if ``arch`` is not a name defined in
            ``torchvision.models``.
        requests.HTTPError: if downloading the example image returns an
            HTTP error status.
    """
    # Get a network pre-trained on ImageNet.
    model = torchvision.models.__dict__[arch](pretrained=True)

    # Switch to eval mode to make the visualization deterministic.
    model.eval()

    # We do not need grads for the parameters.
    for param in model.parameters():
        param.requires_grad_(False)

    # Download an example image from wikimedia.
    # These imports are kept local so that they are only required when this
    # function is actually called.
    import requests
    from io import BytesIO
    from PIL import Image

    url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/7/7f/Arthur_Heyer_-_Dog_and_Cats.jpg/592px-Arthur_Heyer_-_Dog_and_Cats.jpg'
    response = requests.get(url)
    # Fail with a clear HTTP error instead of handing an error page to PIL,
    # which would otherwise fail with an unrelated image-decoding error.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))

    # Pre-process the image and convert into a tensor.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(shape),
        torchvision.transforms.CenterCrop(shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    # Add a leading batch dimension.
    x = transform(img).unsqueeze(0)

    # bulldog category id.
    category_id_1 = 245

    # persian cat category id.
    # NOTE(review): confirm this label against IMAGENET_CLASSES[285].
    category_id_2 = 285

    # Move model and input to device.
    from torchray.utils import get_device
    dev = get_device()
    model = model.to(dev)
    x = x.to(dev)

    return model, x, category_id_1, category_id_2
def plot_example(input,
                 saliency,
                 method,
                 category_id,
                 show_plot=False,
                 save_path=None):
    """Plot an example.

    Clears the current matplotlib figure and draws one row per input image:
    the image on the left and its saliency map on the right, titled with the
    method name and the category name and id.

    Args:
        input (:class:`torch.Tensor`): 4D tensor containing input images.
        saliency (:class:`torch.Tensor`): 4D tensor containing saliency maps.
        method (str): name of saliency method.
        category_id (int or sequence of int): ID(s) of ImageNet category.
            If fewer IDs than images are given, the IDs are cycled.
        show_plot (bool, optional): If True, show plot. Default: ``False``.
        save_path (str, optional): Path to save figure to. Missing parent
            directories are created. If the path has no file extension,
            the output format is left to matplotlib's default.
            Default: ``None``.
    """
    from torchray.utils import imsc
    from torchray.benchmark.datasets import IMAGENET_CLASSES

    if isinstance(category_id, int):
        category_id = [category_id]

    batch_size = len(input)

    plt.clf()
    for i in range(batch_size):
        # Cycle through the category ids when there are fewer ids than images.
        class_i = category_id[i % len(category_id)]

        # Left column: the input image.
        plt.subplot(batch_size, 2, 1 + 2 * i)
        imsc(input[i])
        plt.title('input image', fontsize=8)

        # Right column: the saliency map for this image.
        plt.subplot(batch_size, 2, 2 + 2 * i)
        imsc(saliency[i], interpolation='none')
        plt.title('{} for category {} ({})'.format(
            method, IMAGENET_CLASSES[class_i], class_i), fontsize=8)

    # Save figure if path is specified.
    if save_path:
        save_dir = os.path.dirname(os.path.abspath(save_path))
        # Create directory if necessary; exist_ok avoids a race between
        # checking for the directory and creating it.
        os.makedirs(save_dir, exist_ok=True)
        ext = os.path.splitext(save_path)[1].strip('.')
        # An empty format string is rejected by savefig, so pass None when
        # the path has no extension and let matplotlib pick its default.
        plt.savefig(save_path, format=ext or None, bbox_inches='tight')

    # Show plot if desired.
    if show_plot:
        plt.show()