
Compare Layer Weight Initializers

This example shows how to train deep learning networks with different weight initializers.

When training a deep learning network, the initialization of layer weights and biases can strongly affect how well the network trains. The choice of initializer matters more for networks without batch normalization layers.

Depending on the type of layer, you can change the weights and bias initialization using the WeightsInitializer, InputWeightsInitializer, RecurrentWeightsInitializer, and BiasInitializer options.

This example compares the effect of three different weight initializers when training an LSTM network:

  1. Glorot Initializer – Initialize the LSTM input weights and the fully connected weights with the Glorot initializer. [1]

  2. He Initializer – Initialize the LSTM input weights and the fully connected weights with the He initializer. [2]

  3. Narrow-Normal Initializer – Initialize the LSTM input weights and the fully connected weights by independently sampling from a normal distribution with zero mean and standard deviation 0.01.
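For reference, the three sampling rules can be summarized as follows. This is a sketch that assumes the definitions given in the MATLAB layer documentation: n_in is the number of inputs to the layer (its input size), and n_out is the number of outputs, which for the LSTM input weights is 4 times the number of hidden units because the weights of the four gates are stacked.

```latex
\begin{aligned}
\text{Glorot:}\quad & W_{ij} \sim \mathcal{U}(-a,\ a), \quad a = \sqrt{\frac{6}{n_\text{in}+n_\text{out}}}, \quad \operatorname{Var}(W_{ij}) = \frac{a^2}{3} = \frac{2}{n_\text{in}+n_\text{out}} \\
\text{He:}\quad & W_{ij} \sim \mathcal{N}\!\left(0,\ \frac{2}{n_\text{in}}\right) \\
\text{Narrow-normal:}\quad & W_{ij} \sim \mathcal{N}\!\left(0,\ 0.01^2\right)
\end{aligned}
```

Glorot scales the weights by both the fan-in and the fan-out, He by the fan-in only, and narrow-normal uses a fixed scale regardless of layer size.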

Load Data

Load the Japanese Vowels data set that contains sequences of varying length with a feature dimension of 12 and a categorical vector of labels 1,2,...,9. The sequences are matrices with 12 rows (one row for each feature) and a varying number of columns (one column for each time step).
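The following sketch builds a small, made-up cell array with the same layout as XTrain, one numFeatures-by-numTimeSteps matrix per sequence, and recovers the feature dimension and sequence lengths with cellfun. The variable XDemo and the sequence lengths are hypothetical placeholders, not the Japanese Vowels data.

```matlab
% Hypothetical stand-in for XTrain: a cell array of sequences, each a
% numFeatures-by-numTimeSteps matrix (one row per feature, one column per time step).
rng(0)                                    % reproducible placeholder data
numFeatures = 12;
seqLengths = [20; 15; 9; 26; 12];         % made-up, varying sequence lengths
numObs = numel(seqLengths);
XDemo = cell(numObs,1);
for i = 1:numObs
    XDemo{i} = randn(numFeatures,seqLengths(i));
end

% Number of rows (features) and columns (time steps) of each sequence.
featureDims = cellfun(@(x) size(x,1),XDemo);
timeSteps = cellfun(@(x) size(x,2),XDemo);
```

Applied to the loaded data, the same cellfun call on XTrain should give a feature dimension of 12 for every sequence, while the time-step counts vary.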

load JapaneseVowelsTrainData
load JapaneseVowelsTestData

Specify Network Architecture

Specify the network architecture. For each initializer, use the same network architecture.

Specify the input size as 12 (the number of features of the input data). Specify an LSTM layer with 100 hidden units that outputs the last element of the sequence. Finally, specify nine classes by including a fully connected layer of size 9, followed by a softmax layer.

numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer]
layers = 
  4x1 Layer array with layers:

     1   ''   Sequence Input    Sequence input with 12 dimensions
     2   ''   LSTM              LSTM with 100 hidden units
     3   ''   Fully Connected   9 fully connected layer
     4   ''   Softmax           softmax
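As a quick check on what the initializers act on, this sketch computes the sizes of the learnable parameters in the architecture above with plain arithmetic. It assumes the MATLAB layout in which the LSTM input weights are 4*NumHiddenUnits-by-InputSize (the four gate matrices stacked) and the fully connected weights are OutputSize-by-InputSize. In this example, only the LSTM input weights and the fully connected weights change between runs; the recurrent weights and biases keep their default initializers.

```matlab
numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

% LSTM learnables: four gates stacked, so each has 4*numHiddenUnits rows.
szInputWeights = [4*numHiddenUnits numFeatures];         % set by InputWeightsInitializer
szRecurrentWeights = [4*numHiddenUnits numHiddenUnits];  % default initializer
szLSTMBias = [4*numHiddenUnits 1];                       % default initializer

% Fully connected learnables.
szFCWeights = [numClasses numHiddenUnits];               % set by WeightsInitializer
szFCBias = [numClasses 1];                               % default initializer

sizes = {szInputWeights,szRecurrentWeights,szLSTMBias,szFCWeights,szFCBias};
numLearnables = sum(cellfun(@prod,sizes))                    % 46109
numInitialized = prod(szInputWeights) + prod(szFCWeights)    % 5700
```

So the three runs differ in how 5700 of the 46109 learnable values start out.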

Training Options

Specify the training options. For each initializer, use the same training options to train the network.

maxEpochs = 30;
miniBatchSize = 27;
numObservations = numel(XTrain);
numIterationsPerEpoch = floor(numObservations / miniBatchSize);

options = trainingOptions("adam", ...
    ExecutionEnvironment="cpu", ...
    MaxEpochs=maxEpochs, ...
    InputDataFormats="CTB", ...
    Metrics="accuracy", ...
    MiniBatchSize=miniBatchSize, ...
    GradientThreshold=2, ...
    ValidationData={XTest,TTest}, ...
    ValidationFrequency=numIterationsPerEpoch, ...
    Verbose=false, ...
    Plots="training-progress");
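The following arithmetic sketch shows why setting ValidationFrequency to numIterationsPerEpoch validates the network once per epoch. The value 270 for the number of training sequences is an assumption for illustration; the example itself computes it with numel(XTrain).

```matlab
numObservations = 270;   % assumed number of training sequences (numel(XTrain) in the example)
miniBatchSize = 27;
maxEpochs = 30;

% Full mini-batches per pass through the training data.
numIterationsPerEpoch = floor(numObservations/miniBatchSize);
totalIterations = numIterationsPerEpoch*maxEpochs;

% Iterations that are multiples of ValidationFrequency = numIterationsPerEpoch.
% Each one falls at the end of an epoch.
validationIterations = numIterationsPerEpoch:numIterationsPerEpoch:totalIterations;
```

This sketch only counts the multiples of ValidationFrequency; it does not model whether trainnet also evaluates the network before the first iteration.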

Glorot Initializer

Specify the network architecture listed earlier in the example and set the input weights initializer of the LSTM layer and the weights initializer of the fully connected layer to "glorot".

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="glorot")
    fullyConnectedLayer(numClasses,WeightsInitializer="glorot")
    softmaxLayer];

Train the network using the trainnet function with the Glorot weights initializers.

[netGlorot,infoGlorot] = trainnet(XTrain,TTrain,layers,"crossentropy",options);
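This sketch reproduces the Glorot rule with base MATLAB functions for the two weight matrices that the example initializes. It assumes, as described in the MATLAB layer documentation, a uniform distribution with variance 2/(numIn + numOut), where numOut = 4*NumHiddenUnits for the LSTM input weights. The handles glorot and glorotBound are illustrative helpers, not toolbox functions.

```matlab
rng(0)                                   % reproducible draws
numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

% Weight matrices are numOut-by-numIn (rows = outputs, columns = inputs).
szLSTM = [4*numHiddenUnits numFeatures]; % LSTM input weights, four stacked gates
szFC = [numClasses numHiddenUnits];      % fully connected weights

% Draw from U(-b,b) with b = sqrt(6/(numIn + numOut)),
% which gives variance b^2/3 = 2/(numIn + numOut).
glorotBound = @(sz) sqrt(6/(sz(1) + sz(2)));
glorot = @(sz) glorotBound(sz)*(2*rand(sz) - 1);

WLSTM = glorot(szLSTM);
WFC = glorot(szFC);
```

The sample standard deviation of WLSTM should be close to sqrt(2/412), and every entry lies strictly inside the bound.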

He Initializer

Specify the network architecture listed earlier in the example and set the input weights initializer of the LSTM layer and the weights initializer of the fully connected layer to "he".

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="he")
    fullyConnectedLayer(numClasses,WeightsInitializer="he")
    softmaxLayer];

Train the network using the trainnet function with the He weights initializers.

[netHe,infoHe] = trainnet(XTrain,TTrain,layers,"crossentropy",options);
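To see how much the He rule changes the starting scale compared with Glorot in this particular network, the following sketch samples He weights and compares the implied standard deviations. It assumes a normal distribution with zero mean and variance 2/numIn, with numIn taken as the number of columns of each weight matrix; the handle he is an illustrative helper, not a toolbox function.

```matlab
rng(0)                                           % reproducible draws
numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

% He: zero-mean normal with variance 2/numIn, numIn = number of columns.
he = @(sz) sqrt(2/sz(2))*randn(sz);

WLSTM = he([4*numHiddenUnits numFeatures]);      % numIn = 12
WFC = he([numClasses numHiddenUnits]);           % numIn = 100

% Standard deviations implied by the He and Glorot rules for each matrix.
stdHe = [sqrt(2/numFeatures) sqrt(2/numHiddenUnits)];
stdGlorot = [sqrt(2/(numFeatures + 4*numHiddenUnits)) ...
             sqrt(2/(numHiddenUnits + numClasses))];
ratio = stdHe./stdGlorot
```

For the LSTM input weights the ratio is sqrt(412/12), about 5.9, because the LSTM has few inputs and many stacked outputs. For the fully connected weights it is sqrt(109/100), about 1.04, so the two rules start that layer at nearly the same scale.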

Narrow-Normal Initializer

Specify the network architecture listed earlier in the example and set the input weights initializer of the LSTM layer and the weights initializer of the fully connected layer to "narrow-normal".

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="narrow-normal")
    fullyConnectedLayer(numClasses,WeightsInitializer="narrow-normal")
    softmaxLayer];

Train the network using the trainnet function with the narrow-normal weights initializers.

[netNarrowNormal,infoNarrowNormal] = trainnet(XTrain,TTrain,layers,"crossentropy",options);
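The initializer options also accept a function handle of the form weights = func(sz), which makes it easy to try standard deviations other than 0.01. The following sketch is one way to build such a handle. The names scaledNormal and narrowNormal are hypothetical, and the layer calls appear only as comments so that the sketch runs in base MATLAB.

```matlab
% Returns an initializer handle of the form weights = func(sz) that draws
% zero-mean normal samples with standard deviation sigma.
scaledNormal = @(sigma) @(sz) sigma*randn(sz);

% sigma = 0.01 gives the same distribution as the "narrow-normal" option.
narrowNormal = scaledNormal(0.01);

rng(0)                                 % reproducible draws
W = narrowNormal([4*100 12]);          % size of the LSTM input weights in this example

% Possible use with Deep Learning Toolbox layers (not run in this sketch):
%   lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer=narrowNormal)
%   fullyConnectedLayer(numClasses,WeightsInitializer=narrowNormal)
```

Passing, for example, scaledNormal(0.05) instead would let you check how sensitive training is to the fixed scale.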

Plot Results

Extract the validation accuracy from the ValidationHistory property of the training information returned by the trainnet function.

validationAccuracy = [
    infoGlorot.ValidationHistory.Accuracy,...
    infoHe.ValidationHistory.Accuracy,...
    infoNarrowNormal.ValidationHistory.Accuracy];

Each column of validationAccuracy holds the validation history of one initializer. Remove any rows in which the validation accuracy was not computed (all NaN values).

idx = all(isnan(validationAccuracy),2);
validationAccuracy(idx,:) = [];
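With iterations along the rows and initializers along the columns, the dimension argument of all decides whether rows or columns are tested. The small matrix below is made up purely to illustrate the difference.

```matlab
% Made-up accuracies: rows = iterations, columns = initializers.
A = [10  12  11
     NaN NaN NaN
     50  55  40
     NaN NaN NaN];

% all(...) on a matrix tests each column by default; all(...,2) tests each row.
emptyColumns = all(isnan(A));      % no initializer is entirely NaN
emptyRows = all(isnan(A),2);       % rows 2 and 4 have no computed value

% Removing the all-NaN rows keeps one row per iteration with a computed value.
A(emptyRows,:) = [];
```

After the deletion, A keeps the two rows that contain values and all three initializer columns.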

For each of the initializers, plot the epoch numbers against the validation accuracy.

figure
epochs = 0:maxEpochs;
plot(epochs,validationAccuracy)
ylim([0 100])
title("Validation Accuracy")
xlabel("Epoch")
ylabel("Validation Accuracy")
legend(["Glorot" "He" "Narrow-Normal"],Location="southeast")

[Figure: "Validation Accuracy" plot of validation accuracy against epoch, with one line each for the Glorot, He, and Narrow-Normal initializers.]

This plot shows the overall effect of the different initializers and how quickly the training converges for each one.
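One way to put a number on how quickly each run converges is to find the first epoch at which its validation accuracy reaches a chosen threshold. The following sketch uses made-up accuracy values and an arbitrary 90% threshold for illustration; they are not results from this example.

```matlab
% Made-up validation accuracies (percent): rows = epochs 0..4, columns = initializers.
epochs = (0:4)';
acc = [10 12 11
       60 75 20
       85 92 40
       93 95 70
       95 96 91];
threshold = 90;                          % arbitrary target accuracy

% First epoch at which each column reaches the threshold (NaN if never).
firstEpoch = nan(1,size(acc,2));
for k = 1:size(acc,2)
    i = find(acc(:,k) >= threshold,1);
    if ~isempty(i)
        firstEpoch(k) = epochs(i);
    end
end
```

To apply it to this example, replace acc with validationAccuracy and epochs with (0:maxEpochs)'.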

Bibliography

  1. Glorot, Xavier, and Yoshua Bengio. "Understanding the Difficulty of Training Deep Feedforward Neural Networks." In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pp. 249-256. 2010.

  2. He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1026-1034. 2015.
