Code Generation for Deep Learning Networks
This example shows how to generate CUDA code for an image classification application that uses deep learning. It uses the codegen command to generate a MEX function that runs prediction by using the ResNet-50 image classification network.
Third-Party Prerequisites
This example generates CUDA® MEX and has the following third-party requirements.
CUDA-enabled NVIDIA® GPU and compatible driver.
For non-MEX builds, such as static libraries, dynamic libraries, or executables, this example has the following additional requirements.
NVIDIA CUDA toolkit.
Environment variables for the compilers and libraries. For more information, see Third-Party Hardware (GPU Coder) and Setting Up the Prerequisite Products (GPU Coder).
Verify GPU Environment
Use the coder.checkGpuInstall (GPU Coder) function to verify that the compilers and libraries necessary for running this example are set up correctly.
envCfg = coder.gpuEnvConfig('host');
envCfg.DeepLibTarget = 'none';
envCfg.DeepCodegen = 1;
envCfg.Quiet = 1;
coder.checkGpuInstall(envCfg);
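Because the Quiet property suppresses the printed report, it can help to capture the structure that coder.checkGpuInstall returns and inspect it directly. The following is a minimal sketch under that assumption. The variable name results is illustrative, and the exact fields of the returned structure depend on the installed release.

```matlab
% Sketch: keep the structure returned by coder.checkGpuInstall so that the
% outcome of each check can be inspected even when Quiet is enabled.
envCfg = coder.gpuEnvConfig('host');
envCfg.DeepLibTarget = 'none';
envCfg.DeepCodegen = 1;
envCfg.Quiet = 1;

results = coder.checkGpuInstall(envCfg);  % "results" is an illustrative name
disp(results)                             % lists the checks and whether each passed
```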
Classify Images by Using the ResNet-50 Network
ResNet-50 is a convolutional neural network that is 50 layers deep and can classify images into 1000 object categories. A pretrained ResNet-50 model for MATLAB® is available in the Deep Learning Toolbox™ model for ResNet-50 Network support package. Use the Add-On Explorer to download and install the support package.
[net, classNames] = imagePretrainedNetwork('resnet50');
disp(net)
  dlnetwork with properties:

         Layers: [176×1 nnet.cnn.layer.Layer]
    Connections: [191×2 table]
     Learnables: [214×3 table]
          State: [106×3 table]
     InputNames: {'input_1'}
    OutputNames: {'fc1000_softmax'}
    Initialized: 1

  View summary with summary.
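The input size and the number of classes that matter for the later codegen call can be read from the loaded objects. This is a small sketch that assumes the first layer of the network is its image input layer. The variable names inputSize and numClasses are illustrative.

```matlab
% Sketch: query the expected input size and the number of class labels.
[net, classNames] = imagePretrainedNetwork('resnet50');

% Assumption: the first layer is the image input layer of the network.
inputSize = net.Layers(1).InputSize;
numClasses = numel(classNames);

fprintf('Input size: %s, number of classes: %d\n', mat2str(inputSize), numClasses);
```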
The resnet_predict Entry-Point Function
The resnet_predict.m entry-point function takes an image input and runs prediction on the image by using the pretrained ResNet-50 deep learning network. On the first call, the function uses the imagePretrainedNetwork function to load the dlnetwork object into the persistent variable dlnet. On subsequent calls, the function reuses the persistent object for prediction. The dlarray object is created within the entry-point function, so the input and output of the entry-point function are primitive data types. For more information, see Code Generation for dlarray (GPU Coder).
type('resnet_predict.m')
function out = resnet_predict(in) %#codegen
% Copyright 2020-2024 The MathWorks, Inc.

persistent dlnet;

dlIn = dlarray(in, 'SSC');
if isempty(dlnet)
    % Call imagePretrainedNetwork, which returns a dlnetwork object
    % for the ResNet-50 model.
    dlnet = imagePretrainedNetwork('resnet50');
end

dlOut = predict(dlnet, dlIn);
out = extractdata(dlOut);

end
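The persistent-variable pattern used by the entry-point function can be seen in isolation with a toy example. In this sketch, scaleWithCache and gain are hypothetical stand-ins for resnet_predict and dlnet, and the constant assignment stands in for the expensive network load. The initialization message is displayed only on the first call.

```matlab
% Sketch: the first call initializes the persistent state, later calls reuse it.
y1 = scaleWithCache(2);
y2 = scaleWithCache(3);
disp([y1 y2])

function out = scaleWithCache(in)
% scaleWithCache is a hypothetical stand-in for resnet_predict.
persistent gain
if isempty(gain)
    gain = 10;  % stands in for loading the network once
    disp('Initialized persistent state')
end
out = gain * in;
end
```

Clearing the function (for example, clear scaleWithCache) resets the persistent state, so the next call initializes it again.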
Run MEX Code Generation
To generate CUDA code for the resnet_predict.m entry-point function, create a GPU code configuration object for a MEX target. Use the coder.DeepLearningConfig (GPU Coder) function to create a deep learning code configuration object and assign it to the DeepLearningConfig property of the GPU code configuration object. Run the codegen command and specify an input size of 224-by-224-by-3, which corresponds to the input layer size of the network.
cfg = coder.gpuConfig('mex');
dlcfg = coder.DeepLearningConfig(TargetLibrary = "none");
cfg.DeepLearningConfig = dlcfg;
codegen -config cfg resnet_predict -args {ones(224,224,3,'single')} -report
Code generation successful: View report
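For the non-MEX builds mentioned in the prerequisites, the same workflow applies with a different build type. The following is a sketch of a static library build, assuming the CUDA toolkit and the required environment variables are already set up. The variable name libCfg is illustrative.

```matlab
% Sketch: generate a static library instead of a MEX function.
% Assumes the additional non-MEX prerequisites are installed and configured.
libCfg = coder.gpuConfig('lib');
libCfg.DeepLearningConfig = coder.DeepLearningConfig(TargetLibrary = "none");

codegen -config libCfg resnet_predict -args {ones(224,224,3,'single')} -report
```

Passing 'dll' or 'exe' to coder.gpuConfig selects a dynamic library or an executable target. An executable target also needs a main function, which this sketch does not cover.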
Run Generated MEX
Call resnet_predict_mex on the input image.
im = imread('peppers.png');
im = imresize(im, [224,224]);
predict_scores = resnet_predict_mex(single(im));
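One way to gain confidence in the generated MEX function is to compare its output with the output of the original entry-point function running in MATLAB. This is a sketch only: the variable names are illustrative, and small floating-point differences between the two results are to be expected, so no specific tolerance is assumed here.

```matlab
% Sketch: compare the MEX output against the MATLAB entry-point function.
im = imresize(imread('peppers.png'), [224,224]);
in = single(im);

mexScores = resnet_predict_mex(in);  % generated code
refScores = resnet_predict(in);      % original MATLAB function

maxAbsDiff = max(abs(mexScores(:) - refScores(:)));
fprintf('Maximum absolute difference: %g\n', maxAbsDiff);
```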
Map the Prediction Scores to Labels and Display Output
Get the top five prediction scores and their labels.
[scores,indx] = sort(predict_scores, 'descend');
classNamesTop = classNames(indx(1:5));

h = figure;
h.Position(3) = 2*h.Position(3);
ax1 = subplot(1,2,1);
ax2 = subplot(1,2,2);

image(ax1,im);
barh(ax2,scores(5:-1:1))
xlabel(ax2,'Probability')
yticklabels(ax2,classNamesTop(5:-1:1))
ax2.YAxisLocation = 'right';
sgtitle('Top Five Predictions That Use ResNet-50')
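The score-to-label mapping can also be written with maxk, which returns only the largest values and their indices instead of sorting the full vector. The sketch below uses a small synthetic score vector and made-up labels (scoresDemo and namesDemo are hypothetical) so that it runs without the network.

```matlab
% Sketch: top-five selection with maxk on synthetic data.
scoresDemo = single([0.05 0.40 0.10 0.02 0.25 0.03 0.15]);  % hypothetical scores
namesDemo = ["a" "b" "c" "d" "e" "f" "g"];                  % hypothetical labels

[topScores, topIdx] = maxk(scoresDemo, 5);  % five largest scores, in descending order
topNames = namesDemo(topIdx);

disp(topNames)
disp(topScores)
```

With the real outputs, the same call would be maxk(predict_scores, 5), and classNames(topIdx) would give the corresponding labels.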
Clear the static network object that was loaded in memory.
clear resnet_predict_mex;