How to train a network on image and feature data for regression?
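% Wrap each array in a datastore and combine them as {image, features, response}
dsX1Train = arrayDatastore(X1Train,IterationDimension=4);  % 224x224x3xN images
dsX2Train = arrayDatastore(X2Train);                       % features, assumed N-by-46 (one row per observation)
dsTTrain = arrayDatastore(TTrain);                         % regression targets, assumed N-by-1
dsTrain = combine(dsX1Train,dsX2Train,dsTTrain);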
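A quick way to confirm that the combined datastore returns the inputs in the expected order and shape is to read a single observation. The sketch below uses small random arrays as hypothetical stand-ins for X1Train, X2Train, and TTrain (the real data is not shown in the post), assuming X2Train is an N-by-46 matrix and TTrain an N-by-1 vector.

```matlab
% Hypothetical stand-in data: 10 observations (the real arrays are not shown)
numObs  = 10;
X1Train = rand(224,224,3,numObs);   % images, observations along dim 4
X2Train = rand(numObs,46);          % features, one row per observation
TTrain  = rand(numObs,1);           % regression targets

dsX1Train = arrayDatastore(X1Train,IterationDimension=4);
dsX2Train = arrayDatastore(X2Train);   % default IterationDimension=1 (rows)
dsTTrain  = arrayDatastore(TTrain);
dsTrain   = combine(dsX1Train,dsX2Train,dsTTrain);

% Read one observation: a 1-by-3 cell {image, features, target}
sample = read(dsTrain);
reset(dsTrain);                     % rewind before passing to trainNetwork

disp(size(sample{1}))   % image slice
disp(size(sample{2}))   % feature row, should match imageInputLayer([1 46 1])
disp(size(sample{3}))   % scalar target
```

If the second size is not 1-by-46, the feature datastore is iterating over the wrong dimension for the `[1 46 1]` input layer.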
%%
% Build the two-branch network
lgraph = layerGraph();

% Branch 1: convolutional feature extractor for the 224x224x3 images
tempLayers = [
    imageInputLayer([224 224 3],"Name","imageinput_1")
    convolution2dLayer([3 3],8,"Name","conv_1","Padding","same")
    batchNormalizationLayer("Name","batchnorm_1")
    reluLayer("Name","relu_1")
    averagePooling2dLayer([2 2],"Name","avgpool2d_1","Stride",[2 2])
    convolution2dLayer([3 3],16,"Name","conv_2","Padding","same")
    batchNormalizationLayer("Name","batchnorm_2")
    reluLayer("Name","relu_2")
    averagePooling2dLayer([2 2],"Name","avgpool2d_2","Stride",[2 2])
    convolution2dLayer([3 3],32,"Name","conv_3","Padding","same")
    batchNormalizationLayer("Name","batchnorm_3")
    reluLayer("Name","relu_3")
    convolution2dLayer([3 3],32,"Name","conv_4","Padding","same")
    batchNormalizationLayer("Name","batchnorm_4")
    reluLayer("Name","relu_4")
    dropoutLayer(0.2,"Name","dropout")
    fullyConnectedLayer(1,"Name","fc_1")];
lgraph = addLayers(lgraph,tempLayers);

% Branch 2: the 46-element feature vector, treated as a 1x46x1 "image"
tempLayers = [
    imageInputLayer([1 46 1],"Name","imageinput_2")
    fullyConnectedLayer(1,"Name","fc_2")];
lgraph = addLayers(lgraph,tempLayers);

% Merge: concatenate the two branch outputs, then regress to a single value
tempLayers = [
    concatenationLayer(2,2,"Name","concat")
    fullyConnectedLayer(1,"Name","fc_3")
    regressionLayer("Name","regressionoutput")];
lgraph = addLayers(lgraph,tempLayers);
clear tempLayers;

% Connect both branches to the concatenation layer
lgraph = connectLayers(lgraph,"fc_2","concat/in1");
lgraph = connectLayers(lgraph,"fc_1","concat/in2");
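As I understand the `trainNetwork` documentation, for a multi-input layer graph the datastore columns are matched to the inputs in the order of the graph's `InputNames` property, with the responses last. The sketch below rebuilds a reduced version of the graph (the image branch is shortened to one layer, purely for illustration) to show that order; the `concat/in1` and `concat/in2` wiring only affects the concatenation order, not which datastore column feeds which input.

```matlab
% Reduced version of the two-branch graph (image branch shortened for brevity)
lgraph = layerGraph();
lgraph = addLayers(lgraph,[
    imageInputLayer([224 224 3],"Name","imageinput_1")
    fullyConnectedLayer(1,"Name","fc_1")]);
lgraph = addLayers(lgraph,[
    imageInputLayer([1 46 1],"Name","imageinput_2")
    fullyConnectedLayer(1,"Name","fc_2")]);
lgraph = addLayers(lgraph,[
    concatenationLayer(2,2,"Name","concat")
    fullyConnectedLayer(1,"Name","fc_3")
    regressionLayer("Name","regressionoutput")]);
lgraph = connectLayers(lgraph,"fc_2","concat/in1");
lgraph = connectLayers(lgraph,"fc_1","concat/in2");

% Datastore columns 1 and 2 should follow this order; the response comes last
disp(lgraph.InputNames)
```

Here the order is image first, features second, which matches `combine(dsX1Train,dsX2Train,dsTTrain)`.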
%%
% Training options, then train the network
options = trainingOptions("sgdm", ...
    MaxEpochs=15, ...
    InitialLearnRate=0.001, ...
    Plots="training-progress", ...
    Verbose=false);
net = trainNetwork(dsTrain,lgraph,options);
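A NaN loss after a few iterations often means the optimization is diverging. One common mitigation is a smaller learning rate, an adaptive solver, and gradient clipping. The values below are illustrative assumptions, not settings tuned for this data.

```matlab
% Illustrative, untuned settings aimed at avoiding a diverging (NaN) loss
options = trainingOptions("adam", ...
    MaxEpochs=15, ...
    MiniBatchSize=16, ...           % assumption: small batches for 224x224 images
    InitialLearnRate=1e-4, ...      % lower than the original 1e-3
    GradientThreshold=1, ...        % clip gradients (default method "l2norm")
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=false);
```

Gradient clipping caps the size of each update, so a single bad batch is less likely to push the weights to Inf or NaN.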
I referenced this example: https://www.mathworks.com/help/deeplearning/ug/train-network-on-image-and-feature-data.html?s_tid=srchtitle_Train%20Network%20on%20Image%20and%20Feature_1
However, training fails with the following warning:
Warning: Training stops at iteration 3 because the training loss is NaN. Predictions using the output network may contain NaN values.
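Another frequent cause of a NaN loss in regression is an unnormalized target or feature scale: large values make the squared-error loss and its gradients large enough to blow up early in training. A minimal sketch of z-score normalizing the features and targets before building the datastores, using hypothetical random data; predictions must then be mapped back with the same mean and standard deviation.

```matlab
% Hypothetical data standing in for the real arrays
numObs  = 100;
X2Train = 1000*rand(numObs,46);       % features on a large scale
TTrain  = 500 + 200*randn(numObs,1);  % targets far from zero

% Z-score the features column by column
muX  = mean(X2Train,1);
sigX = std(X2Train,0,1);
sigX(sigX == 0) = 1;                  % avoid division by zero for constant columns
X2TrainNorm = (X2Train - muX) ./ sigX;

% Z-score the targets
muT = mean(TTrain);
sigT = std(TTrain);
TTrainNorm = (TTrain - muT) / sigT;

% After training on TTrainNorm, map predictions back to the original scale
% (dsTest is a hypothetical test datastore):
% YPred = predict(net,dsTest)*sigT + muT;
```

The same `muX`, `sigX`, `muT`, and `sigT` computed on the training set should be reused for validation and test data.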
1 Comment
yanqi liu
on 14 Mar 2022
Yes, sir. Maybe check the data for NaN values. If possible, upload your data so it can be analyzed.
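Checking the data as suggested can be done per observation, so the bad samples can be inspected or dropped. The sketch below uses small hypothetical arrays (32x32 images instead of 224x224 to keep it light) with a few deliberately corrupted entries; `isfinite` flags both NaN and Inf.

```matlab
% Hypothetical data with a few corrupted entries for illustration
numObs  = 8;
X1Train = rand(32,32,3,numObs);
X2Train = rand(numObs,46);
TTrain  = rand(numObs,1);
X1Train(1,1,1,3) = NaN;             % corrupt observation 3 (image)
X2Train(5,10)    = Inf;             % corrupt observation 5 (features)
TTrain(7)        = NaN;             % corrupt observation 7 (target)

% Flag observations whose image, features, or target are not finite
badImg  = squeeze(any(~isfinite(X1Train),[1 2 3]));   % numObs-by-1
badFeat = any(~isfinite(X2Train),2);
badResp = ~isfinite(TTrain);
bad     = badImg | badFeat | badResp;

disp(find(bad)')    % observations to inspect or drop

% Keep only clean observations before building the datastores
X1Train = X1Train(:,:,:,~bad);
X2Train = X2Train(~bad,:);
TTrain  = TTrain(~bad);
```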
Answers (0)