Grad-CAM Reveals the Why Behind Deep Learning Decisions
This example shows how to use the gradient-weighted class activation mapping (Grad-CAM) technique to understand why a deep learning network makes its classification decisions. Grad-CAM, invented by Selvaraju and coauthors [1], uses the gradient of the classification score with respect to the convolutional features determined by the network to identify which parts of the image are most important for classification. This example uses the pretrained GoogLeNet image classification network.
Grad-CAM is a generalization of the class activation mapping (CAM) technique. For activation mapping techniques on live webcam data, see Investigate Network Predictions Using Class Activation Mapping. Grad-CAM can also be applied to nonclassification examples such as regression or semantic segmentation. For an example showing how to use Grad-CAM to investigate the predictions of a semantic segmentation network, see Explore Semantic Segmentation Network Using Grad-CAM.
Load Pretrained Network
Load the GoogLeNet network.
[net,classNames] = imagePretrainedNetwork("googlenet");
Classify Image
Read the GoogLeNet image size.
inputSize = net.Layers(1).InputSize(1:2);
Load sherlock.jpg, an image of a golden retriever included with this example.
img = imread("sherlock.jpg");
Resize the image to the network input dimensions.
img = imresize(img,inputSize);
For single observation input, make predictions using the predict function. To make predictions using the GPU, first convert the data to gpuArray. Making predictions on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
X = single(img);
if canUseGPU
    X = gpuArray(X);
end
scores = predict(net,X);
Y = scores2label(scores,classNames);
imshow(img)
title(Y)
GoogLeNet correctly classifies the image as a golden retriever. But why? What characteristics of the image cause the network to make this classification?
Grad-CAM Explains Why
The Grad-CAM technique uses the gradients of the classification score with respect to the final convolutional feature map to identify the parts of an input image that most affect the classification score. The places where this gradient is large are exactly the places where the final score depends most on the data.
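As a rough illustration of that computation, the following sketch performs the gradient-weighted combination manually with dlfeval and dlgradient. The layer names "inception_5b-output" and "prob" are assumptions based on the GoogLeNet architecture; verify them with analyzeNetwork(net) before use.

```matlab
function map = manualGradCAM(net,img,channel)
% Hedged sketch of the computation gradCAM performs, assuming net is a
% dlnetwork (as returned by imagePretrainedNetwork) and that
% "inception_5b-output" and "prob" are the final convolutional feature
% layer and the softmax layer, respectively.
X = dlarray(single(img),"SSCB");
[featMap,grad] = dlfeval(@gradFcn,net,X,channel);
weights = mean(grad,[1 2]);        % average each gradient channel over space
map = sum(featMap.*weights,3);     % channel-weighted sum of feature maps
map = extractdata(max(map,0));     % ReLU, then strip the dlarray wrapper
map = rescale(imresize(map,size(img,[1 2])));  % match the input image size
end

function [featMap,grad] = gradFcn(net,X,channel)
% Forward pass that also returns the intermediate feature map, then
% differentiates the class score with respect to that feature map.
[featMap,scores] = forward(net,X, ...
    Outputs=["inception_5b-output" "prob"]);
grad = dlgradient(scores(channel),featMap);   % d(score)/d(features)
end
```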
The gradCAM function computes the importance map by taking the derivative of the reduction layer output for a given class with respect to a convolutional feature map. For classification tasks, the gradCAM function automatically selects suitable layers to compute the importance map for. You can also specify the layers using the 'ReductionLayer' and 'FeatureLayer' name-value arguments.
Compute the Grad-CAM map.
channel = find(Y == categorical(classNames));
map = gradCAM(net,img,channel);
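If the automatically selected layers are not the ones you want, you can name them explicitly. The layer names below are assumptions based on the GoogLeNet architecture, not values taken from this example; check them with analyzeNetwork(net) before relying on them.

```matlab
% Hedged sketch: specify the feature and reduction layers explicitly.
% "inception_5b-output" (last convolutional feature map) and "prob"
% (softmax) are assumed GoogLeNet layer names.
map = gradCAM(net,img,channel, ...
    FeatureLayer="inception_5b-output", ...
    ReductionLayer="prob");
```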
Show the Grad-CAM map on top of the image by using an 'AlphaData' value of 0.5. The 'jet' colormap has deep blue as the lowest value and deep red as the highest.
imshow(img)
hold on
imagesc(map,'AlphaData',0.5)
colormap jet
hold off
title("Grad-CAM")
Clearly, the upper face and ear of the dog have the greatest impact on the classification.
For a different approach to investigating the reasons for deep network classifications, see occlusionSensitivity and imageLIME.
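For comparison, a LIME map for the same prediction can be sketched as below. This assumes imageLIME accepts the same network, image, and channel inputs as gradCAM with its default settings; consult the imageLIME documentation for the options it supports.

```matlab
% Hedged sketch: compute and overlay a LIME importance map for the same
% class, using the same 0.5 'AlphaData' overlay style as above.
scoreMap = imageLIME(net,img,channel);
figure
imshow(img)
hold on
imagesc(scoreMap,'AlphaData',0.5)
colormap jet
hold off
title("LIME")
```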
References
[1] Selvaraju, R. R., M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra. "Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization." In IEEE International Conference on Computer Vision (ICCV), 2017, pp. 618–626. Available at Grad-CAM on the Computer Vision Foundation Open Access website.
See Also
gradCAM | imageLIME | occlusionSensitivity | deepDreamImage
Related Topics
- Interpret Deep Learning Time-Series Classifications Using Grad-CAM
- Explore Semantic Segmentation Network Using Grad-CAM
- Investigate Network Predictions Using Class Activation Mapping
- Deep Learning Visualization Methods
- Explore Network Predictions Using Deep Learning Visualization Techniques
- Understand Network Predictions Using LIME