MATLAB Answers

How to classify with DAG network from checkpoint

Yoshinori Abe on 12 Oct 2018
Commented: Katja Mogalle on 19 Oct 2021 at 10:58
I want to use classify() with a DAG network loaded from a checkpoint.
I trained inceptionv3 by transfer learning for many epochs, and training succeeded. I set 'CheckpointPath', so I have a network saved at each epoch. I want to evaluate these networks, so I loaded one and called classify(). But an error occurred, and the message said "Use trainNetwork". How can I use classify() with a network loaded from a checkpoint?
carlos arizmendi on 23 Nov 2019
I now have the same problem with classifying. How did you fix this bug? Thanks a lot.


Accepted Answer

Naoya on 15 Oct 2018
Thank you very much for providing the details.
Checkpoint networks that contain batch normalization layers are not supported for inference in the current release (R2018b). I will forward this functionality to our development team as an enhancement request.
We apologize for the inconvenience caused by the current checkpoint functionality.
Wes Baldwin on 7 Jul 2020
I just had this same issue in R2019b. How is this an enhancement? This is a bug that needs fixing!


More Answers (3)

Katja Mogalle on 30 Apr 2021
@Gediminas Simkus had the right idea for the workaround. I can sketch this out a bit more.
Background information
To make predictions with the network after training, batch normalization requires a fixed mean and variance to normalize the data. By default, this fixed mean and variance are calculated from the training data at the very end of training, using the entire training data set. However, a checkpoint network is saved before the end of training is reached, so its mean and variance values are not set.
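To see whether a loaded checkpoint network is affected, one option is to look for batch normalization layers whose stored statistics are missing. The sketch below assumes that the unset values show up as empty TrainedMean or TrainedVariance properties; findUnsetBatchNormLayers is a hypothetical helper name, not a Deep Learning Toolbox function.

```matlab
function names = findUnsetBatchNormLayers(net)
% Hypothetical helper: return the names of batch normalization layers in
% NET whose TrainedMean or TrainedVariance is empty.
% Assumption: a checkpoint saved before the end of training leaves these
% properties empty.
names = {};
layers = net.Layers;
for i = 1:numel(layers)
    layer = layers(i);
    if isa(layer, 'nnet.cnn.layer.BatchNormalizationLayer') && ...
            (isempty(layer.TrainedMean) || isempty(layer.TrainedVariance))
        names{end+1} = layer.Name; %#ok<AGROW>
    end
end
end
```

For example, call names = findUnsetBatchNormLayers(net) right after loading the checkpoint; a non-empty result suggests the network needs one of the workarounds.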
Two possible solutions
There are two things you can try in order to use checkpoint networks for inference:
  1. Since R2021a, running statistics can be enabled for batch normalization layers. The batch normalization statistics are then computed during training rather than at the end of training, so the checkpoint networks can be used directly without further modification. To do this, set the 'BatchNormalizationStatistics' name-value pair in trainingOptions to 'moving' when training the network with checkpointing.
  2. Use trainNetwork with minimal training to convert the checkpoint network into a network with a fixed batch normalization mean and variance that can be used for inference. This workaround is based on the Resume Training from Checkpoint Network process, with some slight tweaks so that the checkpoint network is modified as little as possible.
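For the first solution, the training options might look like this. A minimal sketch, assuming R2021a or later; the checkpoint folder under tempdir is a placeholder, and it is created up front because trainingOptions expects the CheckpointPath folder to exist.

```matlab
% Placeholder checkpoint folder; replace it with your own path.
checkpointDir = fullfile(tempdir, 'checkpoints');
if ~exist(checkpointDir, 'dir')
    mkdir(checkpointDir);
end

% 'moving' makes batch normalization layers keep running statistics
% during training, so each saved checkpoint already has them (R2021a+).
options = trainingOptions('sgdm', ...
    'CheckpointPath', checkpointDir, ...
    'BatchNormalizationStatistics', 'moving');
```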
Example steps for the second workaround using trainNetwork (tested in R2020a and R2020b)
Load the checkpoint network into the workspace (replace this with your own file).
load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')
Specify the training options such that training is only run for one iteration, the input data statistics of the input layer are not recomputed, and the learnable parameters are only changed minimally.
options = trainingOptions('sgdm', ...
    'InitialLearnRate',eps, ...
    'ResetInputNormalization',false, ...
    'OutputFcn',@(~)true);
Now "resume" training with the new training options, passing the layers of the checkpoint network you loaded. If the checkpoint network is a DAG network, pass layerGraph(net) instead of net.Layers.
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);
The returned network can be used for inference.
YPred = classify(net2,XTrain);
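Because the original question was about evaluating the network saved at each epoch, the same steps can be repeated for every checkpoint file. A sketch, assuming R2019b or later, in-memory arrays XTrain/YTrain/XTest/YTest with YTest a categorical column vector, and checkpoint files that follow the net_checkpoint__*.mat naming shown above; evaluateCheckpoints is a hypothetical helper name. It passes layerGraph(net) for DAG networks such as inceptionv3 and net.Layers otherwise.

```matlab
function results = evaluateCheckpoints(checkpointDir, XTrain, YTrain, XTest, YTest)
% Hypothetical helper: convert every checkpoint network in checkpointDir
% with the minimal-training workaround and report its test accuracy.
files = dir(fullfile(checkpointDir, 'net_checkpoint__*.mat'));
names = strings(numel(files), 1);
accuracy = zeros(numel(files), 1);

% Same options as the workaround, with progress printing turned off.
options = trainingOptions('sgdm', ...
    'InitialLearnRate', eps, ...
    'ResetInputNormalization', false, ...
    'OutputFcn', @(~)true, ...
    'Verbose', false);

for k = 1:numel(files)
    s = load(fullfile(files(k).folder, files(k).name), 'net');
    if isa(s.net, 'DAGNetwork')
        layers = layerGraph(s.net);   % DAG networks need a layer graph
    else
        layers = s.net.Layers;        % series networks use the layer array
    end
    net2 = trainNetwork(XTrain, YTrain, layers, options);
    YPred = classify(net2, XTest);
    names(k) = files(k).name;
    accuracy(k) = mean(YPred == YTest);
end
results = table(names, accuracy, 'VariableNames', {'File', 'Accuracy'});
end
```

The rows follow the order returned by dir, which is not necessarily training order, because the iteration numbers in the file names are not zero-padded.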
I hope this helps.
Katja Mogalle on 19 Oct 2021 at 10:41
The 'ResetInputNormalization' training option was added in R2019b.
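If the same script also has to run on releases older than R2019b, the option can be added only when it is available. A minimal sketch, assuming that R2019b corresponds to MATLAB version 9.7; on older releases the option is simply left out.

```matlab
% Base options for the one-iteration workaround.
args = {'InitialLearnRate', eps, 'OutputFcn', @(~)true};

% Add ResetInputNormalization only on R2019b (MATLAB 9.7) or later.
if ~verLessThan('matlab', '9.7')
    args = [args, {'ResetInputNormalization', false}];
end
options = trainingOptions('sgdm', args{:});
```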



AnaMota on 27 Apr 2021
Is there any solution for this? I am facing the same issue with MATLAB 2020...

Andrea Daou on 8 Oct 2021 at 11:50
Hello,
I know an answer was accepted for this question but I have a response that might be useful.
If using a checkpoint network does not work in your MATLAB version, you can write a custom output function similar to the one in https://fr.mathworks.com/help/deeplearning/ug/customize-output-during-deep-learning-training.html .
For example, instead of validation accuracy, the stopping criterion can be based on validation loss:
function stop = stopIfValidationLossNotDecreasing(info,N,StartPoint)
% Stop training when the validation loss has not decreased for N
% validations in a row. StartPoint is the initial reference loss.
stop = false;

% Keep track of the validation loss and the number of successive
% validations for which there has not been a decrease in the loss.
persistent ValLoss
persistent valLag

if info.State == "start"
    % Clear the variables when training starts.
    ValLoss = StartPoint; % Choose depending on the problem; check the first validation loss.
    valLag = 0;
elseif ~isempty(info.ValidationLoss)
    % Compare the current validation loss with the previous one. If the
    % new loss is lower, reset valLag; otherwise increment it by 1. In
    % both cases, the new loss becomes the reference for the next check.
    if info.ValidationLoss < ValLoss
        valLag = 0;
    else
        valLag = valLag + 1;
    end
    ValLoss = info.ValidationLoss;

    % If the validation loss has not decreased for at least N
    % validations in a row, return true to stop training.
    if valLag >= N
        stop = true;
    end
end
end
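The function can then be passed to trainingOptions as an output function. A sketch, where XVal and YVal are random placeholder data to be replaced with your own validation set, N = 5 is an arbitrary example value, and StartPoint = Inf makes the first validation loss always count as a decrease.

```matlab
% Placeholder validation data (replace with your own images and labels).
XVal = rand(28, 28, 1, 50);
YVal = categorical(randi(3, 50, 1));

N = 5;            % allowed validations without a decrease (example value)
startPoint = Inf; % first validation loss is always treated as a decrease

options = trainingOptions('sgdm', ...
    'ValidationData', {XVal, YVal}, ...
    'ValidationFrequency', 30, ...
    'OutputFcn', @(info) stopIfValidationLossNotDecreasing(info, N, startPoint));
```

This function compares each validation loss with the previous one rather than with the smallest loss seen so far, so any single decrease resets the counter.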
Katja Mogalle on 19 Oct 2021 at 10:58
Hi Andrea,
The training option "ValidationPatience" actually does exactly what you're showing in your code. To stop training when the loss on the validation set stops decreasing, simply specify validation data and a validation patience using the 'ValidationData' and the 'ValidationPatience' name-value pair arguments of trainingOptions, respectively. The validation patience is the number of times that the loss on the validation set can be larger than or equal to the previously smallest loss before network training stops.
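For comparison, the built-in option only needs two name-value arguments. A sketch with placeholder validation data; the patience of 5 and the validation frequency of 30 iterations are arbitrary example values.

```matlab
% Placeholder validation data (replace with your own images and labels).
XVal = rand(28, 28, 1, 50);
YVal = categorical(randi(3, 50, 1));

% Stop when the validation loss is larger than or equal to the smallest
% previous loss 5 times in a row.
options = trainingOptions('sgdm', ...
    'ValidationData', {XVal, YVal}, ...
    'ValidationFrequency', 30, ...
    'ValidationPatience', 5);
```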
But perhaps I am not fully understanding what you are trying to achieve. In that case, could you provide some clarification?
Thanks


Release

R2018b
