Create Simple Image Classification Network
This example shows how to create and train a simple convolutional neural network for deep learning classification. Convolutional neural networks are essential tools for deep learning and are especially suited for image recognition.
The example demonstrates how to:
Load image data.
Define the network architecture.
Specify training options.
Train the network.
Predict the labels of new data and calculate the classification accuracy.
For an example showing how to interactively create and train a simple image classification network, see Get Started with Image Classification.
Load Data
Unzip the digit sample data and create an image datastore. The imageDatastore function automatically labels the images based on folder names.
unzip("DigitsData.zip")
imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
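With LabelSource="foldernames", each image's label is the name of the folder that contains it. The base-MATLAB sketch below mimics that rule on a few hypothetical file paths. The paths are made up for illustration; imageDatastore applies the rule for you, so this code is only meant to show the idea.

```matlab
% Hypothetical file paths laid out like the unzipped DigitsData folder
% (one subfolder per digit). These names are illustrative only.
files = [
    "DigitsData/0/image1.png"
    "DigitsData/3/image7.png"
    "DigitsData/3/image9.png"
    "DigitsData/9/image2.png"];

% The label of each file is the name of its immediate parent folder.
labels = strings(numel(files),1);
for i = 1:numel(files)
    parentDir = fileparts(files(i));        % e.g. "DigitsData/3"
    [~,labels(i)] = fileparts(parentDir);   % e.g. "3"
end

% Store the labels as a categorical array, as the datastore does.
labels = categorical(labels);
disp(categories(labels))
```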
Divide the data into training and validation data sets so that each category in the training set contains 750 images and the validation set contains the remaining images from each label. The splitEachLabel function splits the image datastore into two new datastores for training and validation. The "randomized" option assigns the images of each label to the two sets at random.
numTrainFiles = 750;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,"randomized");
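A per-label randomized split shuffles the images of each class separately and puts a fixed number of them in the training set, so every class is equally represented. The sketch below illustrates that idea on a hypothetical label vector (3 classes of 20 images, 15 per class for training); it is not the implementation of splitEachLabel.

```matlab
% Hypothetical labels: 3 classes ("0", "1", "2") with 20 images each.
labels = categorical(repelem((0:2)',20));
numTrainPerClass = 15;   % images per class for the training set

rng(0)                   % make the random split reproducible
isTrain = false(size(labels));
cats = categories(labels);
for k = 1:numel(cats)
    idx = find(labels == cats{k});             % indices of this class
    idx = idx(randperm(numel(idx)));           % shuffle within the class
    isTrain(idx(1:numTrainPerClass)) = true;   % first ones go to training
end

trainLabels = labels(isTrain);
validationLabels = labels(~isTrain);
countcats(trainLabels)'        % 15 15 15
countcats(validationLabels)'   % 5 5 5
```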
View the class names.
classNames = categories(imdsTrain.Labels)
classNames = 10x1 cell
{'0'}
{'1'}
{'2'}
{'3'}
{'4'}
{'5'}
{'6'}
{'7'}
{'8'}
{'9'}
Define Network Architecture
Define the convolutional neural network architecture. Specify the size of the images in the input layer of the network and the number of classes in the fully connected layer. Each image is 28-by-28-by-1 pixels and there are 10 classes.
inputSize = [28 28 1];
numClasses = 10;

layers = [
    imageInputLayer(inputSize)
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
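To see what sizes flow through this layer array, the arithmetic below works out the activation size after the convolution and the number of learnable parameters in each layer. It assumes the convolution2dLayer defaults of stride 1 and no padding; the Deep Learning Toolbox function analyzeNetwork can display the same kind of information interactively.

```matlab
inputSize = [28 28 1];
numClasses = 10;
filterSize = 5;      % convolution2dLayer(5,20): 5-by-5 filters ...
numFilters = 20;     % ... and 20 of them
stride = 1;          % assumed default stride
padding = 0;         % assumed default padding (none)

% Spatial size after the convolution: floor((in + 2*pad - filter)/stride) + 1
convHW = floor((inputSize(1:2) + 2*padding - filterSize)/stride) + 1;
convOut = [convHW numFilters]           % 24 24 20

% Batch normalization and ReLU keep this size; the fully connected layer
% flattens the 24-by-24-by-20 activations into one feature vector.
numFeatures = prod(convOut)             % 11520

% Learnable parameters: weights + biases, or scales + offsets.
convParams = filterSize^2*inputSize(3)*numFilters + numFilters   % 520
bnParams   = 2*numFilters                                        % 40
fcParams   = numFeatures*numClasses + numClasses                 % 115210
totalParams = convParams + bnParams + fcParams                   % 115770
```

Almost all of the learnable parameters sit in the fully connected layer, because it connects every one of the 11520 convolution outputs to each of the 10 classes.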
For more information about deep learning layers, see List of Deep Learning Layers.
Specify Training Options
Specify the training options. Choosing among the options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.
options = trainingOptions("sgdm", ...
    MaxEpochs=4, ...
    ValidationData=imdsValidation, ...
    ValidationFrequency=30, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);
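MaxEpochs counts passes over the training data, while ValidationFrequency counts iterations (mini-batch updates). The arithmetic below relates the two for this example. It assumes the default MiniBatchSize of 128 and that the incomplete final mini-batch of each epoch is discarded; if your release handles partial mini-batches differently, the counts change slightly.

```matlab
numClasses = 10;
numTrainFiles = 750;        % images per class in the training set
miniBatchSize = 128;        % assumed: default MiniBatchSize
maxEpochs = 4;              % MaxEpochs=4
validationFrequency = 30;   % ValidationFrequency=30 (in iterations)

numTrainImages = numClasses*numTrainFiles               % 7500
% Assumes the incomplete final mini-batch of each epoch is discarded.
iterationsPerEpoch = floor(numTrainImages/miniBatchSize) % 58
totalIterations = maxEpochs*iterationsPerEpoch           % 232

% Iterations at which the validation metrics are computed.
validationIterations = validationFrequency:validationFrequency:totalIterations
```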
Train Neural Network
Train the neural network using the trainnet function. For classification, use cross-entropy loss. By default, the trainnet function uses a GPU if one is available and otherwise uses the CPU. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To specify the execution environment, use the ExecutionEnvironment training option.
net = trainnet(imdsTrain,layers,"crossentropy",options);
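The softmax layer turns the outputs of the fully connected layer into class probabilities, and cross-entropy loss penalizes a low probability on the true class. The sketch below computes both by hand for a single observation with hypothetical scores for 3 classes; trainnet computes the loss internally over each mini-batch.

```matlab
% Hypothetical fully connected layer outputs for one image (3 classes).
z = [2; 1; 0.1];

% Softmax: exponentiate and normalize so the probabilities sum to 1.
% Subtracting max(z) first avoids overflow and does not change the result.
y = exp(z - max(z));
y = y/sum(y);

% One-hot target: the true class is the first one.
t = [1; 0; 0];

% Cross-entropy loss for this observation.
loss = -sum(t .* log(y))
```

The loss is small when the probability of the true class is close to 1 and grows without bound as that probability approaches 0.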
Test Neural Network
To test the neural network, classify the validation data and calculate the classification accuracy using the testnet function. For single-label classification, evaluate the accuracy, which is the percentage of correct predictions. By default, the testnet function uses a GPU if one is available. To select the execution environment manually, use the ExecutionEnvironment argument of the testnet function.
accuracy = testnet(net,imdsValidation,"accuracy")
accuracy = 98.9600
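Accuracy as a percentage of correct predictions is simple to compute once you have predicted and true labels. The sketch below uses hypothetical labels for five images rather than the output of the trained network.

```matlab
% Hypothetical true and predicted labels for five images.
TTrue = categorical([3; 7; 7; 1; 0]);
TPred = categorical([3; 7; 1; 1; 0]);

% Accuracy: percentage of predictions that match the true labels.
acc = 100*mean(TPred == TTrue)     % 80 (4 of 5 correct)
```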
As a next step in deep learning, you can try using a pretrained network for other tasks. Solve new classification problems on your image data with transfer learning or feature extraction. For examples, see Start Deep Learning Faster Using Transfer Learning and Train Classifiers Using Features Extracted from Pretrained Networks. To learn more about pretrained networks, see Pretrained Deep Neural Networks.
See Also
trainnet | trainingOptions | dlnetwork
Related Topics
- Start Deep Learning Faster Using Transfer Learning
- Get Started with Image Classification
- Try Deep Learning in 10 Lines of MATLAB Code
- Classify Image Using Pretrained Network
- Get Started with Transfer Learning
- Prepare Network for Transfer Learning Using Deep Network Designer
- Get Started with Time Series Forecasting