trainingOptions

Options for training a deep learning neural network

Syntax

options = trainingOptions(solverName)
options = trainingOptions(solverName,Name,Value)

Description

options = trainingOptions(solverName) returns training options for the optimizer specified by solverName. To train a network, use the training options as an input argument to the trainNetwork function.

options = trainingOptions(solverName,Name,Value) returns training options with additional options specified by one or more name-value pair arguments.

Examples

Create a set of options for training a network using stochastic gradient descent with momentum. Reduce the learning rate by a factor of 0.2 every 5 epochs. Set the maximum number of epochs for training to 20, and use a mini-batch with 64 observations at each iteration. Turn on the training progress plot.

options = trainingOptions('sgdm', ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.2, ...
    'LearnRateDropPeriod',5, ...
    'MaxEpochs',20, ...
    'MiniBatchSize',64, ...
    'Plots','training-progress')
options = 
  TrainingOptionsSGDM with properties:

                     Momentum: 0.9000
             InitialLearnRate: 0.0100
    LearnRateScheduleSettings: [1x1 struct]
             L2Regularization: 1.0000e-04
      GradientThresholdMethod: 'l2norm'
            GradientThreshold: Inf
                    MaxEpochs: 20
                MiniBatchSize: 64
                      Verbose: 1
             VerboseFrequency: 50
               ValidationData: []
          ValidationFrequency: 50
           ValidationPatience: Inf
                      Shuffle: 'once'
               CheckpointPath: ''
         ExecutionEnvironment: 'auto'
                   WorkerLoad: []
                    OutputFcn: []
                        Plots: 'training-progress'
               SequenceLength: 'longest'
         SequencePaddingValue: 0
         DispatchInBackground: 0

When you train networks for deep learning, it is often useful to monitor the training progress. By plotting various metrics during training, you can learn how the training is progressing. For example, you can determine if and how quickly the network accuracy is improving, and whether the network is starting to overfit the training data.

When you specify 'training-progress' as the 'Plots' value in trainingOptions and start network training, trainNetwork creates a figure and displays training metrics at every iteration. Each iteration is an estimation of the gradient and an update of the network parameters. If you specify validation data in trainingOptions, then the figure shows validation metrics each time trainNetwork validates the network. The figure plots the following:

  • Training accuracy — Classification accuracy on each individual mini-batch.

  • Smoothed training accuracy — Smoothed training accuracy, obtained by applying a smoothing algorithm to the training accuracy. It is less noisy than the unsmoothed accuracy, making it easier to spot trends.

  • Validation accuracy — Classification accuracy on the entire validation set (specified using trainingOptions).

  • Training loss, smoothed training loss, and validation loss — The loss on each mini-batch, its smoothed version, and the loss on the validation set, respectively. If the final layer of your network is a classificationLayer, then the loss function is the cross entropy loss. For more information about loss functions for classification and regression problems, see Output Layers.

For regression networks, the figure plots the root mean square error (RMSE) instead of the accuracy.

The figure marks each training Epoch using a shaded background. An epoch is a full pass through the entire data set.

During training, you can stop training and return the current state of the network by clicking the stop button in the top-right corner. For example, you might want to stop training when the accuracy of the network reaches a plateau and it is clear that the accuracy is no longer improving. After you click the stop button, it can take a while for the training to complete. Once training is complete, trainNetwork returns the trained network.

When training finishes, the Results section of the figure shows the final validation accuracy and the reason that training finished. The final validation metrics are labeled Final in the plots. If your network contains batch normalization layers, then the final validation metrics are often different from the validation metrics evaluated during training. This is because batch normalization layers in the final network perform different operations than during training: during training they normalize using the statistics of each mini-batch, whereas the final network uses statistics computed from the training data.

On the right, view information about the training time and settings. To learn more about training options, see Set Up Parameters and Train Convolutional Neural Network.

Plot Training Progress During Training

Train a network and plot the training progress during training.

Load the training data, which contains 5000 images of digits. Set aside 1000 of the images for network validation.

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

Construct a network to classify the digit image data.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Specify options for network training. To validate the network at regular intervals during training, specify validation data. Choose the 'ValidationFrequency' value so that the network is validated about once per epoch. To plot training progress during training, specify 'training-progress' as the 'Plots' value.

options = trainingOptions('sgdm', ...
    'MaxEpochs',8, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','training-progress');
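The value 30 follows from the data sizes. A minimal sketch of the arithmetic, assuming the default 'MiniBatchSize' of 128 (the example does not set it) and that trainNetwork discards the incomplete final mini-batch of each epoch, as described for the 'Shuffle' option:

```matlab
% Arithmetic behind 'ValidationFrequency',30 in the example above.
numTrain = 5000 - 1000;      % images left after holding out 1000 for validation
miniBatchSize = 128;         % assumption: default 'MiniBatchSize' (not set in the example)
% One iteration uses one full mini-batch; the incomplete final mini-batch
% of each epoch is discarded, hence floor.
iterationsPerEpoch = floor(numTrain/miniBatchSize);
fprintf('Iterations per epoch: %d\n', iterationsPerEpoch);
```

Validating every 30 iterations is therefore roughly once per 31-iteration epoch. To tie the frequency to the data instead of hard-coding it, you could pass floor(numel(YTrain)/miniBatchSize) as the 'ValidationFrequency' value, where miniBatchSize is the mini-batch size you actually use.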

Train the network.

net = trainNetwork(XTrain,YTrain,layers,options);

Input Arguments

Solver for training the network (solverName), specified as one of the following:

  • 'sgdm' — Use the stochastic gradient descent with momentum (SGDM) optimizer. You can specify the momentum value using the 'Momentum' name-value pair argument.

  • 'rmsprop' — Use the RMSProp optimizer. You can specify the decay rate of the squared gradient moving average using the 'SquaredGradientDecayFactor' name-value pair argument.

  • 'adam' — Use the Adam optimizer. You can specify the decay rates of the gradient and squared gradient moving averages using the 'GradientDecayFactor' and 'SquaredGradientDecayFactor' name-value pair arguments, respectively.

For more information about the different solvers, see Stochastic Gradient Descent.

Name-Value Pair Arguments

Specify optional comma-separated pairs of Name,Value arguments. Name is the argument name and Value is the corresponding value. Name must appear inside quotes. You can specify several name and value pair arguments in any order as Name1,Value1,...,NameN,ValueN.

Example: 'InitialLearnRate',0.03,'L2Regularization',0.0005,'LearnRateSchedule','piecewise' specifies the initial learning rate as 0.03 and the L2 regularization factor as 0.0005, and instructs the software to drop the learning rate every given number of epochs by multiplying it by a certain factor.

Plots and Display

Plots to display during network training, specified as the comma-separated pair consisting of 'Plots' and one of the following:

  • 'none' — Do not display plots during training.

  • 'training-progress' — Plot training progress. The plot shows mini-batch loss and accuracy, validation loss and accuracy, and additional information on the training progress. The plot has a stop button in the top-right corner. Click the button to stop training and return the current state of the network. For more information on the training progress plot, see Monitor Deep Learning Training Progress.

Example: 'Plots','training-progress'

Indicator to display training progress information in the command window, specified as the comma-separated pair consisting of 'Verbose' and either 1 (true) or 0 (false).

The verbose output displays the following information:

Classification Networks

Field | Description
Epoch | Epoch number. An epoch corresponds to a full pass of the data.
Iteration | Iteration number. An iteration corresponds to a mini-batch.
Time Elapsed | Time elapsed in hours, minutes, and seconds.
Mini-batch Accuracy | Classification accuracy on the mini-batch.
Validation Accuracy | Classification accuracy on the validation data. If you do not specify validation data, then the function does not display this field.
Mini-batch Loss | Loss on the mini-batch. If the output layer is a ClassificationOutputLayer object, then the loss is the cross entropy loss for multi-class classification problems with mutually exclusive classes.
Validation Loss | Loss on the validation data. If the output layer is a ClassificationOutputLayer object, then the loss is the cross entropy loss for multi-class classification problems with mutually exclusive classes. If you do not specify validation data, then the function does not display this field.
Base Learning Rate | Base learning rate. The software multiplies the learn rate factors of the layers by this value.

Regression Networks

Field | Description
Epoch | Epoch number. An epoch corresponds to a full pass of the data.
Iteration | Iteration number. An iteration corresponds to a mini-batch.
Time Elapsed | Time elapsed in hours, minutes, and seconds.
Mini-batch RMSE | Root-mean-squared-error (RMSE) on the mini-batch.
Validation RMSE | RMSE on the validation data. If you do not specify validation data, then the software does not display this field.
Mini-batch Loss | Loss on the mini-batch. If the output layer is a RegressionOutputLayer object, then the loss is the half-mean-squared-error.
Validation Loss | Loss on the validation data. If the output layer is a RegressionOutputLayer object, then the loss is the half-mean-squared-error. If you do not specify validation data, then the software does not display this field.
Base Learning Rate | Base learning rate. The software multiplies the learn rate factors of the layers by this value.

To specify validation data, use the 'ValidationData' name-value pair.

Example: 'Verbose',false

Frequency of verbose printing, which is the number of iterations between printing to the command window, specified as the comma-separated pair consisting of 'VerboseFrequency' and a positive integer. This option only has an effect when the 'Verbose' value equals true.

If you validate the network during training, then trainNetwork also prints to the command window every time validation occurs.

Example: 'VerboseFrequency',100

Mini-Batch Options

Maximum number of epochs to use for training, specified as the comma-separated pair consisting of 'MaxEpochs' and a positive integer.

An iteration is one step taken in the gradient descent algorithm towards minimizing the loss function using a mini-batch. An epoch is the full pass of the training algorithm over the entire training set.

Example: 'MaxEpochs',20

Size of the mini-batch to use for each training iteration, specified as the comma-separated pair consisting of 'MiniBatchSize' and a positive integer. A mini-batch is a subset of the training set that is used to evaluate the gradient of the loss function and update the weights. See Stochastic Gradient Descent.

Example: 'MiniBatchSize',256

Option for data shuffling, specified as the comma-separated pair consisting of 'Shuffle' and one of the following:

  • 'once' — Shuffle the training and validation data once before training.

  • 'never' — Do not shuffle the data.

  • 'every-epoch' — Shuffle the training data before each training epoch, and shuffle the validation data before each network validation. If the mini-batch size does not evenly divide the number of training samples, then trainNetwork discards the training data that does not fit into the final complete mini-batch of each epoch. To avoid discarding the same data every epoch, set the 'Shuffle' value to 'every-epoch'.

Example: 'Shuffle','every-epoch'
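The number of observations discarded per epoch is the remainder of the number of training observations divided by the mini-batch size. The following sketch uses hypothetical sizes and randperm to mimic which observations fall into that remainder with a single shuffle versus a shuffle before every epoch; it illustrates the behavior described above and is not trainNetwork's internal code.

```matlab
% Hypothetical sizes: 10 observations, mini-batches of 4.
numObs = 10;
miniBatchSize = 4;
numDiscarded = mod(numObs, miniBatchSize);   % observations that cannot fill a final mini-batch
numEpochs = 3;

% 'once': one permutation, reused in every epoch.
order = randperm(numObs);
for epoch = 1:numEpochs
    discardedOnce = order(end-numDiscarded+1:end);    % same indices every epoch
    fprintf('once, epoch %d: discarded %s\n', epoch, mat2str(discardedOnce));
end

% 'every-epoch': a new permutation before each epoch.
for epoch = 1:numEpochs
    order = randperm(numObs);
    discardedEvery = order(end-numDiscarded+1:end);   % usually different indices
    fprintf('every-epoch, epoch %d: discarded %s\n', epoch, mat2str(discardedEvery));
end
```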

Validation

Data to use for validation during training, specified as an image datastore, a datastore that returns data in a two-column table or two-column cell array, a table, or a cell array. The format of the validation data depends on the type of task and corresponds to valid inputs to the trainNetwork function.

Image Data

For image data, specify the validation data as one of the following:

Input | Description | More Information
Image datastore | ImageDatastore object with categorical labels. | imds argument of trainNetwork
Datastore | Datastore that returns data in a two-column table or two-column cell array, where the two columns specify the network inputs and expected responses, respectively. | ds argument of trainNetwork
Table | Table, where the first column contains either image paths or images, and the subsequent columns contain the responses. | tbl argument of trainNetwork
Cell array {X,Y}, element X | Numeric array of images. | X argument of trainNetwork
Cell array {X,Y}, element Y | Categorical vector of labels, matrix of numeric responses, or array of images. | Y argument of trainNetwork

Sequence and Time Series Data

For sequence and time series data, specify the validation data as one of the following:

Input | Description | More Information
Cell array {C,Y}, element C | Cell array of sequences or time series data. | C argument of trainNetwork
Cell array {C,Y}, element Y | Categorical vector of labels, cell array of categorical sequences, matrix of numeric responses, or cell array of numeric sequences. | Y argument of trainNetwork
Table | Table containing absolute or relative file paths to MAT files containing sequence or time series data. | tbl argument of trainNetwork

During training, trainNetwork calculates the validation accuracy and validation loss on the validation data. To specify the validation frequency, use the 'ValidationFrequency' name-value pair argument. You can also use the validation data to stop training automatically when the validation loss stops decreasing. To turn on automatic validation stopping, use the 'ValidationPatience' name-value pair argument.

If your network has layers that behave differently during prediction than during training (for example, dropout layers), then the validation accuracy can be higher than the training (mini-batch) accuracy.

The validation data is shuffled according to the 'Shuffle' value. If the 'Shuffle' value equals 'every-epoch', then the validation data is shuffled before each network validation.

Frequency of network validation in number of iterations, specified as the comma-separated pair consisting of 'ValidationFrequency' and a positive integer.

The 'ValidationFrequency' value is the number of iterations between evaluations of validation metrics. To specify validation data, use the 'ValidationData' name-value pair argument.

Example: 'ValidationFrequency',20

Patience of validation stopping of network training, specified as the comma-separated pair consisting of 'ValidationPatience' and a positive integer or Inf.

The 'ValidationPatience' value is the number of times that the loss on the validation set can be larger than or equal to the previously smallest loss before network training stops. To turn on automatic validation stopping, specify a positive integer as the 'ValidationPatience' value. If you use the default value of Inf, then the training stops after the maximum number of epochs. To specify validation data, use the 'ValidationData' name-value pair argument.

Example: 'ValidationPatience',5
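The stopping rule can be illustrated with a counter over a hypothetical sequence of validation losses. This sketch reads the description above as counting consecutive validation checks without a new smallest loss (the counter resets when a new smallest loss appears); that reset is an assumption, and the code is not trainNetwork's implementation.

```matlab
% Hypothetical validation losses, one per validation check.
valLoss = [0.90 0.70 0.65 0.66 0.64 0.64 0.65 0.67 0.70];
patience = 3;                 % 'ValidationPatience' value

bestLoss = Inf;
numWorse = 0;                 % checks since the last new smallest loss
stoppedAt = NaN;
for k = 1:numel(valLoss)
    if valLoss(k) < bestLoss
        bestLoss = valLoss(k);    % new smallest loss: reset the counter
        numWorse = 0;
    else
        numWorse = numWorse + 1;  % loss >= previous smallest loss
        if numWorse == patience
            stoppedAt = k;        % training would stop at this check
            break
        end
    end
end
fprintf('Stopped at validation check %d (best loss %.2f)\n', stoppedAt, bestLoss);
```

With these numbers the loss at check 6 ties the best value of 0.64, which counts against the patience, and training stops at check 8.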

Solver Options

Initial learning rate used for training, specified as the comma-separated pair consisting of 'InitialLearnRate' and a positive scalar. The default value is 0.01 for the 'sgdm' solver and 0.001 for the 'rmsprop' and 'adam' solvers. If the learning rate is too low, then training takes a long time. If the learning rate is too high, then training might reach a suboptimal result or diverge.

Example: 'InitialLearnRate',0.03

Data Types: single | double

Option for dropping the learning rate during training, specified as the comma-separated pair consisting of 'LearnRateSchedule' and one of the following:

  • 'none' — The learning rate remains constant throughout training.

  • 'piecewise' — The software multiplies the learning rate by a fixed factor every fixed number of epochs. Use the LearnRateDropFactor name-value pair argument to specify the value of this factor. Use the LearnRateDropPeriod name-value pair argument to specify the number of epochs between multiplications.

Example: 'LearnRateSchedule','piecewise'

Number of epochs for dropping the learning rate, specified as the comma-separated pair consisting of 'LearnRateDropPeriod' and a positive integer. This option is valid only when the value of LearnRateSchedule is 'piecewise'.

The software multiplies the global learning rate by the drop factor every time the specified number of epochs passes. Specify the drop factor using the LearnRateDropFactor name-value pair argument.

Example: 'LearnRateDropPeriod',3

Factor for dropping the learning rate, specified as the comma-separated pair consisting of 'LearnRateDropFactor' and a scalar from 0 to 1. This option is valid only when the value of LearnRateSchedule is 'piecewise'.

LearnRateDropFactor is a multiplicative factor to apply to the learning rate every time a certain number of epochs passes. Specify the number of epochs using the LearnRateDropPeriod name-value pair argument.

Example: 'LearnRateDropFactor',0.1

Data Types: single | double
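Together, 'InitialLearnRate', 'LearnRateDropPeriod', and 'LearnRateDropFactor' define a step schedule. A minimal sketch of the learning rate in effect during each epoch, using the values from the first example on this page and assuming the first drop takes effect after LearnRateDropPeriod epochs have completed (so epochs 1 to 5 use the initial rate when the period is 5):

```matlab
initialLearnRate = 0.01;   % default for 'sgdm'
dropFactor = 0.2;          % 'LearnRateDropFactor'
dropPeriod = 5;            % 'LearnRateDropPeriod'
maxEpochs = 20;            % 'MaxEpochs'

epochs = 1:maxEpochs;
% Number of completed drop periods before each epoch starts.
numDrops = floor((epochs - 1)/dropPeriod);
learnRate = initialLearnRate * dropFactor.^numDrops;
disp([epochs' learnRate'])
```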

Factor for L2 regularization (weight decay), specified as the comma-separated pair consisting of 'L2Regularization' and a nonnegative scalar. For more information, see L2 Regularization.

You can specify a multiplier for the L2 regularization for network layers with learnable parameters. For more information, see Set Up Parameters in Convolutional and Fully Connected Layers.

Example: 'L2Regularization',0.0005

Data Types: single | double
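L2 regularization is commonly formulated as adding λ/2 times the sum of squared weights to the loss, where λ is the 'L2Regularization' value, so each weight's gradient gains a term λw; this is why it is also called weight decay. A toy sketch, assuming that formulation, with hypothetical weights and a zero data gradient: each gradient step then multiplies the weights by 1 - αλ.

```matlab
lambda = 0.0005;   % 'L2Regularization' value
alpha = 0.01;      % learning rate
w0 = [0.5; -2];    % hypothetical weights
numSteps = 1000;

w = w0;
for k = 1:numSteps
    gradData = zeros(size(w));        % toy case: the data loss is flat
    grad = gradData + lambda*w;       % gradient of E(w) + lambda*0.5*(w'*w)
    w = w - alpha*grad;               % plain gradient step
end

% Each step multiplies the weights by (1 - alpha*lambda).
wClosedForm = (1 - alpha*lambda)^numSteps * w0;
disp([w wClosedForm])
```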

Contribution of the parameter update step of the previous iteration to the current iteration of stochastic gradient descent with momentum, specified as the comma-separated pair consisting of 'Momentum' and a scalar from 0 to 1. A value of 0 means no contribution from the previous step, whereas a value of 1 means maximal contribution from the previous step.

To specify the 'Momentum' value, you must set solverName to be 'sgdm'. The default value works well for most problems. For more information about the different solvers, see Stochastic Gradient Descent.

Example: 'Momentum',0.95

Data Types: single | double

Decay rate of gradient moving average for the Adam solver, specified as the comma-separated pair consisting of 'GradientDecayFactor' and a scalar from 0 to 1. The gradient decay rate is denoted by β1 in [4].

To specify the 'GradientDecayFactor' value, you must set solverName to be 'adam'. The default value works well for most problems. For more information about the different solvers, see Stochastic Gradient Descent.

Example: 'GradientDecayFactor',0.95

Data Types: single | double

Decay rate of squared gradient moving average for the Adam and RMSProp solvers, specified as the comma-separated pair consisting of 'SquaredGradientDecayFactor' and a scalar from 0 to 1. The squared gradient decay rate is denoted by β2 in [4].

To specify the 'SquaredGradientDecayFactor' value, you must set solverName to be 'adam' or 'rmsprop'. Typical values of the decay rate are 0.9, 0.99, and 0.999, corresponding to averaging lengths of 10, 100, and 1000 parameter updates, respectively. For more information about the different solvers, see Stochastic Gradient Descent.

Example: 'SquaredGradientDecayFactor',0.99

Data Types: single | double

Denominator offset for Adam and RMSProp solvers, specified as the comma-separated pair consisting of 'Epsilon' and a positive scalar. The solver adds the offset to the denominator in the network parameter updates to avoid division by zero.

To specify the 'Epsilon' value, you must set solverName to be 'adam' or 'rmsprop'. The default value works well for most problems. For more information about the different solvers, see Stochastic Gradient Descent.

Example: 'Epsilon',1e-6

Data Types: single | double

Gradient Clipping

Gradient threshold, specified as the comma-separated pair consisting of 'GradientThreshold' and Inf or a positive scalar. If the gradient exceeds the value of GradientThreshold, then the gradient is clipped according to GradientThresholdMethod.

Example: 'GradientThreshold',6

Gradient threshold method used to clip gradient values that exceed the gradient threshold, specified as the comma-separated pair consisting of 'GradientThresholdMethod' and one of the following:

  • 'l2norm' — If the L2 norm of the gradient of a learnable parameter is larger than GradientThreshold, then scale the gradient so that the L2 norm equals GradientThreshold.

  • 'global-l2norm' — If the global L2 norm, L, is larger than GradientThreshold, then scale all gradients by a factor of GradientThreshold/L. The global L2 norm considers all learnable parameters.

  • 'absolute-value' — If the absolute value of an individual partial derivative in the gradient of a learnable parameter is larger than GradientThreshold, then scale the partial derivative to have magnitude equal to GradientThreshold and retain the sign of the partial derivative.

For more information, see Gradient Clipping.

Example: 'GradientThresholdMethod','global-l2norm'
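The three methods differ in what they measure and what they rescale. A minimal sketch that applies each rule, written directly from the descriptions above rather than taken from trainNetwork, to hypothetical gradients of two learnable parameters:

```matlab
threshold = 1;               % 'GradientThreshold' value
g1 = [3 4];                  % hypothetical gradient of parameter 1 (L2 norm 5)
g2 = [0.3 -0.4];             % hypothetical gradient of parameter 2 (L2 norm 0.5)

% 'l2norm': clip each parameter's gradient separately.
clipL2 = @(g) g * min(1, threshold/norm(g));
l2a = clipL2(g1);            % rescaled to L2 norm 1
l2b = clipL2(g2);            % unchanged: norm already below the threshold

% 'global-l2norm': one norm over all parameters, one common scale factor.
L = sqrt(sum(g1.^2) + sum(g2.^2));
scale = min(1, threshold/L);
ga = g1*scale;
gb = g2*scale;

% 'absolute-value': clip each partial derivative, keeping its sign.
clipAbs = @(g) sign(g) .* min(abs(g), threshold);
absa = clipAbs(g1);          % [1 1]
absb = clipAbs(g2);          % unchanged
```

Note that 'global-l2norm' also shrinks g2, even though its own norm is below the threshold, because the scale factor is shared across parameters.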

Sequence Options

Option to pad, truncate, or split input sequences, specified as the comma-separated pair consisting of 'SequenceLength' and one of the following:

  • 'longest' — Pad sequences in each mini-batch to have the same length as the longest sequence. This option does not discard any data, though padding can introduce noise to the network.

  • 'shortest' — Truncate sequences in each mini-batch to have the same length as the shortest sequence. This option ensures that no padding is added, at the cost of discarding data.

  • Positive integer — For each mini-batch, pad the sequences to the smallest multiple of the specified length that is greater than or equal to the longest sequence length in the mini-batch, and then split the sequences into smaller sequences of the specified length. If splitting occurs, then the software creates extra mini-batches. Use this option if the full sequences do not fit in memory. Alternatively, try reducing the number of sequences per mini-batch by setting the 'MiniBatchSize' option to a lower value.

If you specify the sequence length as a positive integer, then the software processes the smaller sequences in consecutive iterations. The network updates the network state between the split sequences.

The software pads and truncates the sequences on the right. To learn more about the effect of padding, truncating, and splitting the input sequences, see Sequence Padding, Truncation, and Splitting.

Example: 'SequenceLength','shortest'
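A minimal sketch of the resulting lengths for one mini-batch of hypothetical sequences, written from the descriptions above. For the positive-integer option it assumes the padded length is the smallest multiple of the specified length that is at least the longest sequence length:

```matlab
seqLengths = [7 12 20];     % hypothetical sequence lengths in one mini-batch

% 'longest': pad every sequence on the right to the longest length.
lenLongest = max(seqLengths);                 % no data discarded
% 'shortest': truncate every sequence on the right to the shortest length.
lenShortest = min(seqLengths);                % no padding added
discarded = sum(seqLengths - lenShortest);    % time steps thrown away

% Positive integer: pad to a multiple of the length, then split.
splitLength = 8;
paddedLength = splitLength*ceil(max(seqLengths)/splitLength);
numChunks = paddedLength/splitLength;         % processed in consecutive iterations
fprintf('longest %d, shortest %d (discards %d steps), %d chunks of %d\n', ...
    lenLongest, lenShortest, discarded, numChunks, splitLength);
```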

Value by which to pad input sequences, specified as the comma-separated pair consisting of 'SequencePaddingValue' and a scalar. The option is valid only when SequenceLength is 'longest' or a positive integer. Do not pad sequences with NaN, because doing so can propagate errors throughout the network.

Example: 'SequencePaddingValue',-1

Hardware Options

Hardware resource for training the network, specified as the comma-separated pair consisting of 'ExecutionEnvironment' and one of the following:

  • 'auto' — Use a GPU if one is available. Otherwise, use the CPU.

  • 'cpu' — Use the CPU.

  • 'gpu' — Use the GPU.

  • 'multi-gpu' — Use multiple GPUs on one machine, using a local parallel pool. If no pool is open, then the software opens one based on your default parallel settings.

  • 'parallel' — Use a local parallel pool or compute cluster. If no pool is open, then the software opens one using the default cluster profile. If the pool has access to GPUs, then only workers with a unique GPU perform training computation. If the pool does not have GPUs, then the training takes place on all cluster CPUs.

For more information on when to use the different execution environments, see Scale Up Deep Learning in Parallel and in the Cloud.

GPU, multi-GPU, and parallel options require Parallel Computing Toolbox™. To use a GPU for deep learning, you must also have a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher. If you choose one of these options and Parallel Computing Toolbox or a suitable GPU is not available, then the software returns an error.

To see an improvement in performance when training in parallel, try increasing the MiniBatchSize training option to offset the communication overhead.

Training long short-term memory (LSTM) networks supports only single-CPU and single-GPU training.

Certain input datastores support background dispatch with parallel or multi-GPU execution environments. These datastores are: augmentedImageDatastore, pixelLabelImageDatastore, denoisingImageDatastore, and randomPatchExtractionDatastore. Other input datastores do not support 'parallel' or 'multi-gpu' values of ExecutionEnvironment with DispatchInBackground.

Example: 'ExecutionEnvironment','cpu'

Parallel worker load division between GPUs or CPUs, specified as the comma-separated pair consisting of 'WorkerLoad' and one of the following:

  • Scalar from 0 to 1 — Fraction of workers on each machine to use for network training computation. If you train the network using data in a mini-batch datastore with background dispatch enabled, then the remaining workers fetch and preprocess data in the background.

  • Positive integer — Number of workers on each machine to use for network training computation. If you train the network using data in a mini-batch datastore with background dispatch enabled, then the remaining workers fetch and preprocess data in the background.

  • Numeric vector — Network training load for each worker in the parallel pool. For a vector W, worker i gets a fraction W(i)/sum(W) of the work (number of examples per mini-batch). If you train a network using data in a mini-batch datastore with background dispatch enabled, then you can assign a worker load of 0 to use that worker for fetching data in the background. The specified vector must contain one value per worker in the parallel pool.

If the parallel pool has access to GPUs, then workers without a unique GPU are never used for training computation. The default for pools with GPUs is to use all workers with a unique GPU for training computation, and the remaining workers for background dispatch. If the pool does not have access to GPUs and CPUs are used for training, then the default is to use one worker per machine for background data dispatch.
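For the numeric-vector form, each worker's share follows directly from W(i)/sum(W). A small worked example with a hypothetical four-worker pool and a mini-batch of 128 observations; the numbers are chosen to divide evenly, and how trainNetwork rounds uneven splits is not covered here.

```matlab
W = [1 1 2 0];                        % hypothetical 'WorkerLoad' vector, one entry per worker
miniBatchSize = 128;
share = W/sum(W);                     % fraction of each mini-batch per worker
obsPerWorker = miniBatchSize*share;   % observations per worker per mini-batch
% The worker with load 0 does no training computation; with background
% dispatch enabled it can fetch data instead.
disp(obsPerWorker)
```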

Use asynchronous prefetch queuing to read training data from datastores, specified as the comma-separated pair consisting of 'DispatchInBackground' and either false or true. Asynchronous prefetch queuing requires Parallel Computing Toolbox.

Certain input datastores support background dispatch with parallel or multi-GPU execution environments. These datastores are: augmentedImageDatastore, pixelLabelImageDatastore, denoisingImageDatastore, and randomPatchExtractionDatastore. Other input datastores do not support 'parallel' or 'multi-gpu' values of ExecutionEnvironment with DispatchInBackground.

Checkpoints

Path for saving the checkpoint networks, specified as the comma-separated pair consisting of 'CheckpointPath' and a character vector.

  • If you do not specify a path (that is, you use the default ''), then the software does not save any checkpoint networks.

  • If you specify a path, then trainNetwork saves checkpoint networks to this path after every epoch and assigns a unique name to each network. You can then load any checkpoint network and resume training from that network.

    If the folder does not exist, then you must first create it before specifying the path for saving the checkpoint networks. If the path you specify does not exist, then trainingOptions returns an error.

For more information about saving network checkpoints, see Save Checkpoint Networks and Resume Training.

Example: 'CheckpointPath','C:\Temp\checkpoint'

Data Types: char

Output functions to call during training, specified as the comma-separated pair consisting of 'OutputFcn' and a function handle or cell array of function handles. trainNetwork calls the specified functions once before the start of training, after each iteration, and once after training has finished. trainNetwork passes a structure containing information in the following fields:

Field | Description
Epoch | Current epoch number
Iteration | Current iteration number
TimeSinceStart | Time in seconds since the start of training
TrainingLoss | Current mini-batch loss
ValidationLoss | Loss on the validation data
BaseLearnRate | Current base learning rate
TrainingAccuracy | Accuracy on the current mini-batch (classification networks)
TrainingRMSE | RMSE on the current mini-batch (regression networks)
ValidationAccuracy | Accuracy on the validation data (classification networks)
ValidationRMSE | RMSE on the validation data (regression networks)
State | Current training state, with a possible value of "start", "iteration", or "done"

If a field is not calculated or relevant for a certain call to the output functions, then that field contains an empty array.

You can use output functions to display or plot progress information, or to stop training. To stop training early, make your output function return true. If any output function returns true, then training finishes and trainNetwork returns the latest network. For an example showing how to use output functions, see Customize Output During Deep Learning Network Training.

Data Types: function_handle | cell
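As an illustration, an output function that stops training once the validation accuracy reaches a threshold might look like the following. The name stopAtAccuracy and the threshold of 95 are hypothetical, and the sketch assumes accuracy is reported in percent. It reads only the ValidationAccuracy field listed above, which is empty on calls where no validation took place.

```matlab
% Hypothetical output function: stop when validation accuracy >= 95 (percent).
% ValidationAccuracy is empty on calls without validation, so check that first.
stopAtAccuracy = @(info) ~isempty(info.ValidationAccuracy) && ...
    info.ValidationAccuracy >= 95;

% Pass it to trainingOptions (requires Deep Learning Toolbox):
% options = trainingOptions('sgdm', ...
%     'ValidationData',{XValidation,YValidation}, ...
%     'OutputFcn',stopAtAccuracy);

% Quick check with hand-made info structures:
stopEarly = stopAtAccuracy(struct('ValidationAccuracy',[]));   % false: no validation yet
stopLate = stopAtAccuracy(struct('ValidationAccuracy',96.5));  % true: threshold reached
```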

Output Arguments

Training options, returned as a TrainingOptionsSGDM, TrainingOptionsRMSProp, or TrainingOptionsADAM object. To train a neural network, use the training options as an input argument to the trainNetwork function.

If solverName equals 'sgdm', 'rmsprop', or 'adam', then the training options are returned as a TrainingOptionsSGDM, TrainingOptionsRMSProp, or TrainingOptionsADAM object, respectively.

Algorithms

Initial Weights and Biases

The default for the initial weights is a Gaussian distribution with a mean of 0 and a standard deviation of 0.01. The default for the initial bias value is 0. You can manually change the initialization for the weights and biases. See Specify Initial Weights and Biases in Convolutional Layer and Specify Initial Weights and Biases in Fully Connected Layer.

Stochastic Gradient Descent

The standard gradient descent algorithm updates the network parameters (weights and biases) to minimize the loss function by taking small steps at each iteration in the direction of the negative gradient of the loss,

$$\theta_{\ell+1} = \theta_\ell - \alpha \nabla E(\theta_\ell),$$

where $\ell$ is the iteration number, $\alpha > 0$ is the learning rate, $\theta$ is the parameter vector, and $E(\theta)$ is the loss function. In the standard gradient descent algorithm, the gradient of the loss function, $\nabla E(\theta)$, is evaluated using the entire training set at once.

By contrast, at each iteration the stochastic gradient descent algorithm evaluates the gradient and updates the parameters using a subset of the training data. A different subset, called a mini-batch, is used at each iteration. The full pass of the training algorithm over the entire training set using mini-batches is one epoch. Stochastic gradient descent is stochastic because the parameter update computed using a mini-batch is a noisy estimate of the parameter update that would result from using the full data set. You can specify the mini-batch size and the maximum number of epochs by using the 'MiniBatchSize' and 'MaxEpochs' name-value pair arguments, respectively.
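A minimal sketch of mini-batch stochastic gradient descent on a toy least-squares problem with hypothetical, noise-free data y = 1 + 2x, so the exact minimizer is known. For a deterministic illustration the mini-batches are visited in a fixed order; in practice the data is typically shuffled (see the 'Shuffle' option).

```matlab
x = linspace(0,1,100)';
y = 1 + 2*x;                       % noise-free targets
X = [ones(size(x)) x];             % bias column and feature column
theta = zeros(2,1);                % parameter vector [bias; slope]
alpha = 0.5;                       % learning rate
miniBatchSize = 10;
numEpochs = 300;                   % 10 iterations per epoch

for epoch = 1:numEpochs
    for first = 1:miniBatchSize:numel(y)
        idx = first:first+miniBatchSize-1;     % one mini-batch
        r = X(idx,:)*theta - y(idx);           % residuals on the mini-batch
        grad = X(idx,:)'*r/miniBatchSize;      % gradient of 0.5*mean(r.^2)
        theta = theta - alpha*grad;            % theta_{l+1} = theta_l - alpha*grad
    end
end
disp(theta')   % close to [1 2]
```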

Stochastic Gradient Descent with Momentum

The stochastic gradient descent algorithm can oscillate along the path of steepest descent towards the optimum. Adding a momentum term to the parameter update is one way to reduce this oscillation [2]. The stochastic gradient descent with momentum (SGDM) update is

$$\theta_{\ell+1} = \theta_\ell - \alpha \nabla E(\theta_\ell) + \gamma(\theta_\ell - \theta_{\ell-1}),$$

where γ determines the contribution of the previous gradient step to the current iteration. You can specify this value using the 'Momentum' name-value pair argument. To train a neural network using the stochastic gradient descent with momentum algorithm, specify solverName as 'sgdm'. To specify the initial value of the learning rate α, use the 'InitialLearnRate' name-value pair argument. You can also specify different learning rates for different layers and parameters. For more information, see Set Up Parameters in Convolutional and Fully Connected Layers.
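A minimal sketch of the SGDM update, assuming a deterministic, ill-conditioned quadratic loss so that the effect of the momentum term is easy to see. The matrix A, the starting point, and the comparison with plain gradient descent are illustrative assumptions. The variables alpha and gamma play the roles of α and γ above.

```matlab
% Ill-conditioned quadratic loss E(theta) = 0.5*theta'*A*theta, minimum at 0
A = diag([1, 50]);
gradE = @(theta) A * theta;

alpha = 0.01;        % learning rate (compare 'InitialLearnRate')
gamma = 0.9;         % momentum (compare 'Momentum')

theta = [1; 1];
thetaPrev = theta;   % makes the first momentum term zero

for iter = 1:500
    % Gradient step plus gamma times the previous parameter step
    thetaNext = theta - alpha * gradE(theta) + gamma * (theta - thetaPrev);
    thetaPrev = theta;
    theta = thetaNext;
end

% Plain gradient descent with the same learning rate, for comparison
thetaPlain = [1; 1];
for iter = 1:500
    thetaPlain = thetaPlain - alpha * gradE(thetaPlain);
end
fprintf('Distance to optimum, with momentum: %g, without: %g\n', ...
    norm(theta), norm(thetaPlain));
```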

RMSProp

Stochastic gradient descent with momentum uses a single learning rate for all the parameters. Other optimization algorithms seek to improve network training by using learning rates that differ by parameter and can automatically adapt to the loss function being optimized. RMSProp (root mean square propagation) is one such algorithm. It keeps a moving average of the element-wise squares of the parameter gradients,

$$v_\ell = \beta_2 v_{\ell-1} + (1 - \beta_2)\left[\nabla E(\theta_\ell)\right]^2.$$

β2 is the decay rate of the moving average. Common values of the decay rate are 0.9, 0.99, and 0.999. The corresponding averaging lengths of the squared gradients equal 1/(1-β2), that is, 10, 100, and 1000 parameter updates, respectively. You can specify β2 by using the 'SquaredGradientDecayFactor' name-value pair argument. The RMSProp algorithm uses this moving average to normalize the updates of each parameter individually,

$$\theta_{\ell+1} = \theta_\ell - \frac{\alpha \nabla E(\theta_\ell)}{\sqrt{v_\ell} + \epsilon},$$

where the division is performed element-wise. Using RMSProp effectively decreases the learning rates of parameters with large gradients and increases the learning rates of parameters with small gradients. Here ε is a small constant added to avoid division by zero. You can specify ε by using the 'Epsilon' name-value pair argument, but the default value usually works well. To use RMSProp to train a neural network, specify solverName as 'rmsprop'.
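The sketch below applies the RMSProp update to a hypothetical quadratic loss whose two parameters have gradients that differ in scale by a factor of 100. The loss and the hyperparameter values are assumptions chosen for illustration.

```matlab
% Quadratic loss whose two parameters have very different gradient scales
A = diag([1, 100]);
gradE = @(theta) A * theta;

alpha = 0.01;      % learning rate
beta2 = 0.9;       % decay rate of the moving average (compare 'SquaredGradientDecayFactor')
epsilon = 1e-8;    % avoids division by zero (compare 'Epsilon')

theta = [1; 1];
v = zeros(2, 1);   % moving average of the squared gradients

for iter = 1:1000
    g = gradE(theta);
    v = beta2 * v + (1 - beta2) * g.^2;                % element-wise squares
    theta = theta - alpha * g ./ (sqrt(v) + epsilon);  % element-wise division
end
disp(theta)
```

Because each step is divided element-wise by the square root of the moving average, the update is essentially unchanged when a parameter's gradient is rescaled, so in this toy problem both parameters follow nearly the same trajectory even though their gradients differ by a factor of 100.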

Adam

Adam (derived from adaptive moment estimation) [4] uses a parameter update that is similar to RMSProp, but with an added momentum term. It keeps an element-wise moving average of both the parameter gradients and their squared values,

$$m_\ell = \beta_1 m_{\ell-1} + (1 - \beta_1)\nabla E(\theta_\ell),$$

$$v_\ell = \beta_2 v_{\ell-1} + (1 - \beta_2)\left[\nabla E(\theta_\ell)\right]^2.$$

You can specify the β1 and β2 decay rates using the 'GradientDecayFactor' and 'SquaredGradientDecayFactor' name-value pair arguments, respectively. Adam uses the moving averages to update the network parameters as

$$\theta_{\ell+1} = \theta_\ell - \frac{\alpha m_\ell}{\sqrt{v_\ell} + \epsilon}.$$

If gradients over many iterations are similar, then using a moving average of the gradient enables the parameter updates to pick up momentum in a certain direction. If the gradients contain mostly noise, then the moving average of the gradient becomes smaller, and so the parameter updates become smaller too. You can specify ε by using the 'Epsilon' name-value pair argument. The default value usually works well, but for certain problems a value as large as 1 works better. To use Adam to train a neural network, specify solverName as 'adam'. The full Adam update also includes a mechanism to correct a bias that appears at the beginning of training. For more information, see [4].
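A minimal sketch of the Adam update on a hypothetical quadratic loss with badly scaled parameters. It includes the bias correction from [4], which divides m and v by 1 - β1^ℓ and 1 - β2^ℓ to compensate for initializing both moving averages at zero. Using m and v in place of mHat and vHat gives the simplified update without bias correction. The loss and hyperparameter values are illustrative assumptions.

```matlab
% Quadratic loss with badly scaled parameters, minimum at 0
A = diag([1, 100]);
gradE = @(theta) A * theta;

alpha = 0.01;      % learning rate
beta1 = 0.9;       % compare 'GradientDecayFactor'
beta2 = 0.999;     % compare 'SquaredGradientDecayFactor'
epsilon = 1e-8;    % compare 'Epsilon'

theta = [1; 1];
m = zeros(2, 1);   % moving average of the gradients
v = zeros(2, 1);   % moving average of the squared gradients

for iter = 1:2000
    g = gradE(theta);
    m = beta1 * m + (1 - beta1) * g;
    v = beta2 * v + (1 - beta2) * g.^2;
    % Bias correction from [4]: m and v start at zero, so early averages are too small
    mHat = m / (1 - beta1^iter);
    vHat = v / (1 - beta2^iter);
    theta = theta - alpha * mHat ./ (sqrt(vHat) + epsilon);
end
disp(theta)
```

With bias correction, the very first step from m = v = 0 has magnitude close to alpha in each coordinate, regardless of the gradient scale.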

Specify the learning rate α for all optimization algorithms using the 'InitialLearnRate' name-value pair argument. The effect of the learning rate is different for the different optimization algorithms, so the optimal learning rates are also different in general. You can also specify learning rates that differ by layer and by parameter. For more information, see Set Up Parameters in Convolutional and Fully Connected Layers.

Gradient Clipping

If the gradients increase in magnitude exponentially, then the training is unstable and can diverge within a few iterations. This "gradient explosion" is indicated by a training loss that goes to NaN or Inf. Gradient clipping helps prevent gradient explosion by stabilizing the training at higher learning rates and in the presence of outliers [3]. Gradient clipping enables networks to be trained faster and does not usually impact the accuracy of the learned task.

There are two types of gradient clipping.

  • Norm-based gradient clipping rescales the gradient based on a threshold, and does not change the direction of the gradient. The 'l2norm' and 'global-l2norm' values of GradientThresholdMethod are norm-based gradient clipping methods.

  • Value-based gradient clipping clips any partial derivative whose absolute value is greater than the threshold, which can result in the gradient arbitrarily changing direction. Value-based gradient clipping can have unpredictable behavior, but sufficiently small changes do not cause the network to diverge. The 'absolute-value' value of GradientThresholdMethod is a value-based gradient clipping method.
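The two types of clipping can be compared on a single hypothetical gradient vector. This sketch rescales the whole vector when its L2 norm exceeds the threshold, which preserves its direction, and separately clips each entry to the interval [-threshold, threshold], which in this case changes the direction. It operates on one vector only, so it does not model the difference between per-parameter ('l2norm') and global ('global-l2norm') scaling.

```matlab
% A hypothetical gradient for one learnable parameter
g = [3; -4; 12];          % L2 norm is 13
threshold = 5;            % compare 'GradientThreshold'

% Norm-based clipping: rescale only if the L2 norm exceeds the threshold
gNorm = norm(g);
if gNorm > threshold
    gNormClipped = g * (threshold / gNorm);   % same direction, norm equals threshold
else
    gNormClipped = g;
end

% Value-based clipping: limit each partial derivative to [-threshold, threshold]
gValueClipped = max(min(g, threshold), -threshold);

disp([gNormClipped, gValueClipped])
```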

For examples, see Time Series Forecasting Using Deep Learning and Sequence-to-Sequence Classification Using Deep Learning.

L2 Regularization

Adding a regularization term for the weights to the loss function E(θ) is one way to reduce overfitting [1], [2]. The regularization term is also called weight decay. The loss function with the regularization term takes the form

$$E_R(\theta) = E(\theta) + \lambda \Omega(w),$$

where w is the weight vector, λ is the regularization factor (coefficient), and the regularization function Ω(w) is

$$\Omega(w) = \frac{1}{2} w^T w.$$

Note that the biases are not regularized [2]. You can specify the regularization factor λ by using the 'L2Regularization' name-value pair argument. You can also specify different regularization factors for different layers and parameters. For more information, see Set Up Parameters in Convolutional and Fully Connected Layers.

The loss function that the software uses for network training includes the regularization term. However, the loss value displayed in the command window and training progress plot during training is the loss on the data only and does not include the regularization term.
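The following sketch computes both quantities for a hypothetical layer: the regularized loss E_R(θ) used for training and the data-only loss that is displayed. It also shows that the gradient of λΩ(w) is λw, which is why L2 regularization is also called weight decay. The weights, loss value, and gradient are made up for illustration.

```matlab
% Hypothetical weights, data loss, and data-loss gradient for one layer
% (biases are not regularized, so they do not appear here)
w = [0.5; -1.0; 2.0];
dataLoss = 0.8;             % E(theta): loss on the data only
gradW = [0.1; 0.0; -0.2];   % gradient of E(theta) with respect to w
lambda = 1e-4;              % compare 'L2Regularization'

% Regularization function Omega(w) = 1/2 * w'*w
Omega = 0.5 * (w' * w);

% Regularized loss E_R(theta) = E(theta) + lambda*Omega(w), used for training
regLoss = dataLoss + lambda * Omega;

% The gradient of lambda*Omega(w) is lambda*w, which shrinks the weights
gradWReg = gradW + lambda * w;

fprintf('Displayed (data) loss: %g, training loss: %g\n', dataLoss, regLoss);
```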

Compatibility Considerations


Behavior changed in R2018b

Behavior changed in R2018b

References

[1] Bishop, C. M. Pattern Recognition and Machine Learning. Springer, New York, NY, 2006.

[2] Murphy, K. P. Machine Learning: A Probabilistic Perspective. The MIT Press, Cambridge, Massachusetts, 2012.

[3] Pascanu, R., T. Mikolov, and Y. Bengio. "On the difficulty of training recurrent neural networks." Proceedings of the 30th International Conference on Machine Learning. Vol. 28(3), 2013, pp. 1310–1318.

[4] Kingma, D. P., and J. Ba. "Adam: A method for stochastic optimization." arXiv preprint arXiv:1412.6980, 2014.

Introduced in R2016a