This example shows how to train a network that classifies handwritten digits with a custom learning rate schedule.
Load Training Data
Load the digits data as an image datastore using the imageDatastore function and specify the folder containing the image data.
Partition the data into training and validation sets. Set aside 10% of the data for validation using the splitEachLabel function.
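The two loading steps can be sketched as follows. This assumes the DigitDataset sample images that ship with Deep Learning Toolbox. The folder path is an assumption, so point dataFolder at your own image folder (one subfolder per class) to use different data.

```matlab
% Assumption: the DigitDataset sample folder shipped with Deep Learning Toolbox.
% Replace dataFolder with your own folder (one subfolder per class).
dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset');

% Read the images lazily and label each one by the name of its subfolder.
imds = imageDatastore(dataFolder, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

% Keep 90% of the images of each label for training and 10% for validation.
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,'randomized');

% Show how many validation images each class received.
countEachLabel(imdsValidation)
```

Splitting per label rather than over the whole datastore keeps the class proportions the same in both sets.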
The network used in this example requires input images of size 28-by-28-by-1. To automatically resize the training images, use an augmented image datastore. Specify additional augmentation operations to perform on the training images: randomly translate the images by up to 5 pixels along the horizontal and vertical axes. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.
To automatically resize the validation images without performing further data augmentation, use an augmented image datastore without specifying any additional preprocessing operations.
Determine the number of classes in the training data.
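The data preparation steps above might look like the following sketch. It repeats the loading step so that it runs on its own and assumes the same DigitDataset folder. The variable names are illustrative.

```matlab
% Load and split the data (same folder assumption as the loading step).
dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(dataFolder,'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,'randomized');

% The network expects 28-by-28-by-1 inputs.
inputSize = [28 28 1];

% Random translations of up to 5 pixels along each axis (training data only).
pixelRange = [-5 5];
imageAugmenter = imageDataAugmenter( ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

% Validation images are only resized, never augmented.
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

% Class names and class count, taken from the training labels.
classes = categories(imdsTrain.Labels);
numClasses = numel(classes)
```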
Define Network
Define the network for image classification.
Create a dlnetwork object from the layer graph.
dlnet =
dlnetwork with properties:
Layers: [12×1 nnet.cnn.layer.Layer]
Connections: [11×2 table]
Learnables: [14×3 table]
State: [6×3 table]
InputNames: {'input'}
OutputNames: {'softmax'}
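The following sketch shows one network that is consistent with the display above. The filter sizes and the 20 filters per convolution are assumptions. The sketch does reproduce the displayed counts: 12 layers, 14 learnables (weights and biases of three convolutions and one fully connected layer, plus offsets and scales of three batch normalization layers), and 6 state entries (the running mean and variance of the three batch normalization layers), with input 'input' and output 'softmax'.

```matlab
% Assumption: 10 classes (digits 0-9), as in the digits data.
numClasses = 10;
inputSize = [28 28 1];

layers = [
    imageInputLayer(inputSize,'Normalization','none','Name','input')
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];   % no output layer: dlnetwork does not accept one

% Convert the layer array to a layer graph, then to a dlnetwork.
lgraph = layerGraph(layers);
dlnet = dlnetwork(lgraph)
```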
Define Model Gradients Function
Create the function modelGradients, listed at the end of the example. This function takes a dlnetwork object and a mini-batch of input data with corresponding labels. It returns the gradients of the loss with respect to the learnable parameters in the network, the network state, and the loss.
Specify Training Options
Train for ten epochs with a mini-batch size of 128.
Specify the options for SGDM optimization. Specify an initial learning rate of 0.01, a decay of 0.01, and a momentum of 0.9.
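The options can be stored in a few workspace variables. The sketch below also previews the schedule. It assumes the common time-based decay form, in which the learning rate at iteration t is initialLearnRate/(1 + decay*t). With these values, the rate halves to 0.005 by iteration 100.

```matlab
% Training options.
numEpochs = 10;
miniBatchSize = 128;

% SGDM options.
initialLearnRate = 0.01;
decay = 0.01;
momentum = 0.9;

% Preview of the assumed time-based decay schedule:
%   learnRate(t) = initialLearnRate/(1 + decay*t)
iteration = [1 10 100 1000];
learnRate = initialLearnRate./(1 + decay*iteration);
disp(table(iteration',learnRate','VariableNames',{'Iteration','LearnRate'}))
```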
Train Model
Create a minibatchqueue object that processes and manages mini-batches of images during training. For each mini-batch:
Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to convert the labels to one-hot encoded variables.
Format the image data with the dimension labels 'SSCB' (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.
Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher.
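The queue described above can be sketched as follows. The sketch rebuilds the training datastore so that it runs on its own. It uses the same DigitDataset path assumption and omits augmentation for brevity. It also needs the preprocessMiniBatch and preprocessMiniBatchPredictors functions, listed at the end of the example, to be on the path. 'OutputEnvironment','auto' is already the default. It is set explicitly here only to show where GPU placement is controlled.

```matlab
% Rebuild the training datastore (augmentation omitted for brevity).
dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(dataFolder,'IncludeSubfolders',true,'LabelSource','foldernames');
imdsTrain = splitEachLabel(imds,0.9,'randomized');
augimdsTrain = augmentedImageDatastore([28 28],imdsTrain);
miniBatchSize = 128;

mbq = minibatchqueue(augimdsTrain, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatch, ...   % one-hot encodes the labels
    'MiniBatchFormat',{'SSCB',''}, ...         % format the images only, not the labels
    'OutputEnvironment','auto');               % gpuArray outputs if a GPU is available

% Inspect one mini-batch: images are an SSCB dlarray, labels are unformatted.
[dlX,dlY] = next(mbq);
disp(dims(dlX))
disp(size(dlY))
reset(mbq)   % rewind so that training starts from the first mini-batch
```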
Initialize the training progress plot.
Initialize the velocity parameter for the SGDM solver.
Train the network using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. For each mini-batch:
Evaluate the model gradients, state, and loss using the dlfeval and modelGradients functions and update the network state.
Determine the learning rate for the time-based decay learning rate schedule.
Update the network parameters using the sgdmupdate function.
Display the training progress.
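The steps above, from initializing the progress plot through the parameter update, can be sketched as the following function. The name trainWithTimeDecay and its argument list are hypothetical. The body calls the modelGradients function listed at the end of the example. It also assumes that mbq returns two outputs (images and one-hot labels) and that the time-based decay takes the form initialLearnRate/(1 + decay*iteration).

```matlab
function dlnet = trainWithTimeDecay(dlnet,mbq,numEpochs,initialLearnRate,decay,momentum)
% trainWithTimeDecay  Hypothetical wrapper around the custom training loop.
% Requires modelGradients (listed at the end of the example).

% Initialize the training progress plot.
figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

% An empty velocity tells sgdmupdate to initialize it on the first call.
velocity = [];

iteration = 0;
start = tic;

for epoch = 1:numEpochs
    % Shuffle the data at the start of each epoch.
    shuffle(mbq);

    while hasdata(mbq)
        iteration = iteration + 1;

        % Read a mini-batch of images and one-hot encoded labels.
        [dlX,dlY] = next(mbq);

        % Evaluate the gradients, state, and loss, then update the network
        % state (the batch normalization statistics).
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY);
        dlnet.State = state;

        % Time-based decay: the rate shrinks as 1/(1 + decay*iteration).
        learnRate = initialLearnRate/(1 + decay*iteration);

        % SGDM update of the learnable parameters.
        [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum);

        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end
end
```

With the variables from the earlier steps, call it as dlnet = trainWithTimeDecay(dlnet,mbq,numEpochs,initialLearnRate,decay,momentum). Because dlnetwork is a value object, the function must return the trained network. Because minibatchqueue is a handle object, shuffle(mbq) takes effect without reassignment.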
Test Model
Test the classification accuracy of the model by comparing the predictions on the validation set with the true labels.
After training, making predictions on new data does not require the labels. Create a minibatchqueue object containing only the predictors of the test data:
To ignore the labels for testing, set the number of outputs of the mini-batch queue to 1.
Specify the same mini-batch size used for training.
Preprocess the predictors using the preprocessMiniBatchPredictors function, listed at the end of the example.
For the single output of the datastore, specify the mini-batch format 'SSCB' (spatial, spatial, channel, batch).
Loop over the mini-batches and classify the images using the modelPredictions function, listed at the end of the example.
Evaluate the classification accuracy.
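The following sketch wraps the test steps in a function so that it does not depend on workspace variables. The name evaluateValidationAccuracy is hypothetical. The function relies on the preprocessMiniBatchPredictors and modelPredictions functions listed at the end of the example. It also assumes that the validation datastore is read in its original order, so that the predictions line up with the true labels.

```matlab
function accuracy = evaluateValidationAccuracy(dlnet,augimdsValidation,YTest,classes,miniBatchSize)
% evaluateValidationAccuracy  Hypothetical wrapper around the test steps.

% One output: the queue ignores the labels and returns the predictors only.
numOutputs = 1;
mbqTest = minibatchqueue(augimdsValidation,numOutputs, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatchPredictors, ...
    'MiniBatchFormat','SSCB');

% Classify every validation image.
predictions = modelPredictions(dlnet,mbqTest,classes);

% Fraction of predictions that match the true labels.
accuracy = mean(predictions == YTest);
end
```

With the variables from the earlier steps, call it as accuracy = evaluateValidationAccuracy(dlnet,augimdsValidation,imdsValidation.Labels,classes,miniBatchSize).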
Model Gradients Function
The modelGradients function takes a dlnetwork object dlnet, a mini-batch of input data dlX with corresponding labels Y, and returns the gradients of the loss with respect to the learnable parameters in dlnet, the network state, and the loss. To compute the gradients automatically, use the dlgradient function.
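A sketch of modelGradients follows. It assumes a cross-entropy loss on the softmax output, because the text does not name the loss function. The function must be evaluated through dlfeval so that dlgradient can trace the computation.

```matlab
function [gradients,state,loss] = modelGradients(dlnet,dlX,Y)
% Forward pass in training mode. This also returns the updated network
% state, for example the batch normalization statistics.
[dlYPred,state] = forward(dlnet,dlX);

% Assumed loss: cross-entropy between softmax scores and one-hot targets.
loss = crossentropy(dlYPred,Y);

% Gradients of the loss with respect to the learnable parameters.
gradients = dlgradient(loss,dlnet.Learnables);

% Return the loss as a plain double (on the CPU) for plotting.
loss = double(gather(extractdata(loss)));
end
```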
Model Predictions Function
The modelPredictions function takes a dlnetwork object dlnet, a minibatchqueue of input data mbq, and the network classes, and computes the model predictions by iterating over all data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted class with the highest score.
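A sketch of modelPredictions follows. It uses predict rather than forward, so batch normalization layers use their stored statistics. It also assumes the network output is formatted 'CB', so the class scores lie along dimension 1.

```matlab
function predictions = modelPredictions(dlnet,mbq,classes)
predictions = [];

while hasdata(mbq)
    % Read a mini-batch of predictors only.
    dlXTest = next(mbq);

    % Inference-mode forward pass.
    dlYPred = predict(dlnet,dlXTest);

    % Pick the class with the highest score along dimension 1, then
    % transpose to a column so the batches stack vertically.
    YPred = onehotdecode(dlYPred,classes,1)';
    predictions = [predictions; YPred];
end
end
```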
Mini-Batch Preprocessing Function
The preprocessMiniBatch function preprocesses a mini-batch of predictors and labels using the following steps:
Preprocess the images using the preprocessMiniBatchPredictors function.
Extract the label data from the incoming cell array and concatenate it into a categorical array along the second dimension.
One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.
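A sketch of preprocessMiniBatch that follows the three steps above. It calls preprocessMiniBatchPredictors, described in the next section, and assumes that YCell holds one categorical label per cell, as read from the image datastore.

```matlab
function [X,Y] = preprocessMiniBatch(XCell,YCell)
% Preprocess the predictors (concatenate the images along dimension 4).
X = preprocessMiniBatchPredictors(XCell);

% Concatenate the labels into a 1-by-N categorical array.
Y = cat(2,YCell{1:end});

% One-hot encode along dimension 1, so Y is numClasses-by-N and matches
% the layout of the softmax output.
Y = onehotencode(Y,1);
end
```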
Mini-Batch Predictors Preprocessing Function
The preprocessMiniBatchPredictors function preprocesses a mini-batch of predictors by extracting the image data from the input cell array and concatenating it into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, which serves as a singleton channel dimension.
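A sketch of preprocessMiniBatchPredictors follows. For grayscale images, each cell holds a 28-by-28 matrix. Concatenating along dimension 4 therefore yields a 28-by-28-by-1-by-N array whose singleton third dimension serves as the channel dimension that the 'SSCB' format expects.

```matlab
function X = preprocessMiniBatchPredictors(XCell)
% Concatenate the images along dimension 4 (the batch dimension).
% Dimension 3 remains a singleton and acts as the channel dimension.
X = cat(4,XCell{1:end});
end
```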