How to add new classes to a neural network?
I made a network for flower recognition. It's pretty much a copy of AlexNet, but with some layers deleted. I trained it with 5 classes, but now I want to add more. How can I do that without retraining it from scratch?
allImages = imageDatastore('D:\stuff machine learning\flowers', 'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
[trainingImages, testImages] = splitEachLabel(allImages, 0.8, 'randomize');
conv1 = convolution2dLayer(11,96,'Stride',4,'Padding',0); % 290.5k neurons
conv2 = convolution2dLayer(5,256,'Stride',1,'Padding',2); % 7 million neurons
conv3 = convolution2dLayer(3,384,'Stride',1,'Padding',1);
conv4 = convolution2dLayer(3,384,'Stride',1,'Padding',1);
conv5 = convolution2dLayer(3,256,'Stride',1,'Padding',1);
layers = [...
    imageInputLayer([227 227 3]);
    conv1;
    reluLayer('Name','relu1');
    maxPooling2dLayer(3,'Name','pool1','Stride',2);
    conv2;
    reluLayer('Name','relu2');
    maxPooling2dLayer(3,'Name','pool2','Stride',2);
    conv3;
    reluLayer('Name','relu3');
    conv4;
    reluLayer('Name','relu4');
    conv5;
    reluLayer('Name','relu5');
    maxPooling2dLayer(3,'Name','pool5','Stride',2);
    fullyConnectedLayer(4096,'Name','fc6');
    reluLayer('Name','relu6');
    dropoutLayer('Name','drop6');
    fullyConnectedLayer(4096,'Name','fc7');
    reluLayer('Name','relu7');
    dropoutLayer('Name','drop7');
    fullyConnectedLayer(5,'Name','fc8');
    softmaxLayer('Name','prob');
    classificationLayer('Name','output');]
opts = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.1, ...
    'LearnRateDropPeriod', 10, ...
    'L2Regularization', 0.008, ...
    'MaxEpochs', 30, ...
    'MiniBatchSize', 40, ...
    'ValidationData', testImages, ...
    'Verbose', true, ...
    'Plot', 'training-progress');
testImages.ReadFcn = @readFunctionTrain1;
trainingImages.ReadFcn = @readFunctionTrain1;
% train the network
myNet = trainNetwork(trainingImages, layers, opts);
[YPred,probs] = classify(myNet,testImages);
accuracy = mean(YPred == testImages.Labels)
idx = randperm(numel(testImages.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(testImages,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end
This is the network.
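One common way to add classes without training from scratch is transfer learning: keep the trained layers of myNet, replace only the last fully connected layer (fc8), the softmax layer and the classification layer with new ones sized for the new number of classes, and fine-tune on a dataset that contains both the old and the new classes. The function below is a minimal sketch of that idea, assuming a series network whose last three layers are fc8, softmax and the output layer, as in the layer array above. The function name `replaceClassifier`, the new layer name 'fc8_new' and the learn-rate factor of 10 are illustrative choices, not part of the original code.

```matlab
function layers = replaceClassifier(layers, numClasses)
% replaceClassifier  Sketch of transfer learning for a series network:
% keep every trained layer except the last three (fully connected,
% softmax, classification output) and append fresh ones sized for
% numClasses. Pass the layer array of the trained net, e.g. myNet.Layers.
% Assumption: the array ends with fc8, softmax, classification output, as
% in the question. The learn-rate factor of 10 is an illustrative choice
% that lets the new layer learn faster than the reused layers.
layers = [layers(1:end-3)
          fullyConnectedLayer(numClasses, 'Name', 'fc8_new', ...
              'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10)
          softmaxLayer('Name', 'prob')
          classificationLayer('Name', 'output')];
end
```

For example, after adding folders for the new flower classes under the flowers folder and rerunning the imageDatastore and splitEachLabel lines, `layers = replaceClassifier(myNet.Layers, numel(categories(trainingImages.Labels)));` followed by `myNet = trainNetwork(trainingImages, layers, opts);` would fine-tune the network, probably with a smaller 'InitialLearnRate' and fewer 'MaxEpochs' than the original run. The old classes should stay in the training set: fc8 starts from scratch here, and training only on the new classes would make the network forget the old ones. The time saved comes from reusing the trained convolutional and fc6/fc7 weights.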
1 Comment
Balakrishnan Rajan
on 16 Oct 2018
I am trying to do the same thing. In theory, this should be done by enlarging the Weights matrix and the Bias vector of the fully connected layer and increasing its OutputSize, increasing the OutputSize of the classification output layer, and adding the new category label to its Classes property. However, these properties are read-only.
Peter Gadfort provided a solution in this thread. However, I can't change OutputSize because it is still a read-only property. If you do find a solution, please post it.
The code I am trying is this:
% Adding new classes to a trained net
%%Create an editable net object
load('BestNet.mat')
TempNet = net.saveobj;
%%Edit the properties of the fully connected layer
FCLayer = TempNet.Layers(142,1);
FCOutputSize = FCLayer.OutputSize;
FCLayer.OutputSize = FCOutputSize+1;
FCWeights = FCLayer.Weights;
FCWsize = size(FCWeights)
FCLayer.Weights = rand(FCWsize(1)+1, FCWsize(2));
FCLayer.Weights(1:FCWsize(1),:) = FCWeights;
FCBias = FCLayer.Bias;
FCLayer.Bias = [FCBias; rand(1)]; % append one random bias for the new class
%%Edit the properties of the output layer
OutputLayer = TempNet.Layers(144,1);
OLOutputSize = OutputLayer.OutputSize;
OutputLayer.OutputSize = OLOutputSize + 1;
OLClasses = OutputLayer.Classes;
OLClasses(end+1) = 'Obstructed';
%%Make this the net
net = load.obj(TempNet);
The pretrained net I am using is derived from GoogLeNet, with the last three layers replaced by a fully connected layer, a softmax layer and a classification output layer (cross-entropy loss). I am adding a new class called 'Obstructed'. Sorted alphabetically, it is the last class, which is why I append the new elements after the existing ones.
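Because properties such as OutputSize, Weights and Classes cannot be edited on a layer of a trained network, an alternative is to build a new fully connected layer with the enlarged weights and swap it into a layer graph with `replaceLayer` (a release that provides `replaceLayer`, R2018b or later, is assumed). The sketch below copies the trained rows of the old classes into the new layer and gives the new class small random weights and a zero bias. The function name `growClassifier`, the `_new` layer-name suffix and the 0.01 initialization scale are assumptions for illustration. It also looks up where each old class lands in the new class list instead of assuming the new class sorts last.

```matlab
function lgraph = growClassifier(lgraph, fcName, outName, oldClassNames, newClassNames)
% growClassifier  Sketch: enlarge the final fully connected layer of a
% layer graph without touching read-only properties. A new layer is
% built, the trained rows of the old classes are copied into it, and it
% is swapped in with replaceLayer.
%   fcName, outName : names of the FC and classification output layers
%   oldClassNames   : classes of the trained net, in its output order
%   newClassNames   : all classes, in the order training will use
%                     (e.g. the categories of the training labels)
fcName = char(fcName);
outName = char(outName);
layerArray = lgraph.Layers;
names = arrayfun(@(L) L.Name, layerArray, 'UniformOutput', false);
oldFC = layerArray(strcmp(names, fcName));
oldClassNames = cellstr(oldClassNames);
newClassNames = cellstr(newClassNames);
numClasses = numel(newClassNames);

% New rows start from small random weights (assumed 0.01 scale), zero bias.
W = 0.01 * randn(numClasses, size(oldFC.Weights, 2), 'like', oldFC.Weights);
b = zeros(numClasses, 1, 'like', oldFC.Bias);

% loc(k) is the row of old class k in the new class list, so the new
% class does not have to sort last.
[isKept, loc] = ismember(oldClassNames, newClassNames);
W(loc(isKept), :) = oldFC.Weights(isKept, :);
b(loc(isKept)) = oldFC.Bias(isKept);

newFC = fullyConnectedLayer(numClasses, 'Name', [fcName '_new'], ...
    'Weights', W, 'Bias', b);
lgraph = replaceLayer(lgraph, fcName, newFC);
% A fresh output layer takes its class list from the training labels.
lgraph = replaceLayer(lgraph, outName, classificationLayer('Name', [outName '_new']));
end
```

For the GoogLeNet derivative described above, a call might look like `lgraph = growClassifier(layerGraph(net), net.Layers(142).Name, net.Layers(144).Name, net.Layers(144).Classes, categories(newTrainingImages.Labels));` followed by `net = trainNetwork(newTrainingImages, lgraph, opts);`, where `newTrainingImages` is a hypothetical datastore with images of every old class plus 'Obstructed'. Passing the categories of the training labels as the new class list is meant to match the order in which trainNetwork assigns classes to the outputs.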
Answers (0)