My ONNX network doesn't work when loaded in my Java application

Hi guys! I'm developing a Java application that uses a CNN, specifically an AlexNet. I followed these tutorials: https://it.mathworks.com/matlabcentral/fileexchange/62990-deep-learning-tutorial-series , https://it.mathworks.com/help/deeplearning/ref/alexnet.html and https://it.mathworks.com/help/deeplearning/ref/trainnetwork.html .

Now a weird thing occurs: I trained a network with high accuracy (96.5% on a dataset it had never seen before, taken from the environment the network is meant to work in). Next, I loaded the model in my Java application using the OpenCV method readNetFromONNX ( https://docs.opencv.org/master/javadoc/org/opencv/dnn/Dnn.html ), but the network completely misclassifies. It's not a wrong encoding of the classes; it simply classifies "randomly".

The thing is that if I generate (with the same code and the same dataset) a network with lower accuracy, say about 90% because of a bad choice of training parameters, it works fine when loaded (by "fine" I mean at 90%, obviously). I thought it was overtraining, so I trained the network with the test set and got an accuracy of (surprise, surprise) 99.89%, but when I load it in my Java application (where some inputs from the test set are given) the same thing occurs: it completely misclassifies everything that comes in as input. It's as if Java doesn't accept a network that has high accuracy.

Accepted Answer

Gabija Marsalkaite on 10 Jul 2019
Hi Luigi,
One possibility is a difference in how the data is normalised. In MATLAB, AlexNet applies zero-center normalisation in its first layer, but this may work slightly differently when the network is used from Java. According to the documentation:
  • 'zerocenter' — Subtract the average image specified by the AverageImage property. The trainNetwork function automatically computes the average image at training time.
I expect it should work if you preprocess images this way. Let me know if that solves your issue.
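For the Java side, the zero-center step described above could be sketched roughly as follows. This is a minimal sketch under assumptions, not a confirmed fix: it assumes the pixels arrive as interleaved 8-bit BGR (the channel order OpenCV uses when it decodes an image), that the network expects planar RGB, and that one mean per channel is an acceptable stand-in for the full average image. The class and method names are hypothetical. It is also worth checking whether the exported ONNX file already contains the subtraction, since applying it twice would be just as wrong as not applying it at all.

```java
/** Hypothetical helper: zero-centers an interleaved 8-bit BGR image into a planar RGB float array. */
final class ZeroCenter {

    /**
     * @param bgr     interleaved pixels, 3 bytes per pixel in B, G, R order (assumed input layout)
     * @param height  image height in pixels
     * @param width   image width in pixels
     * @param meanRgb per-channel means in R, G, B order, on the 0..255 scale
     * @return planar floats in channel-major order: all R values, then all G, then all B
     */
    static float[] toZeroCenteredRgb(byte[] bgr, int height, int width, double[] meanRgb) {
        int pixels = height * width;
        if (bgr.length != pixels * 3 || meanRgb.length != 3) {
            throw new IllegalArgumentException("unexpected buffer or mean length");
        }
        float[] out = new float[pixels * 3];
        for (int p = 0; p < pixels; p++) {
            // Java bytes are signed, so mask to recover the 0..255 pixel value.
            int b = bgr[3 * p] & 0xFF;
            int g = bgr[3 * p + 1] & 0xFF;
            int r = bgr[3 * p + 2] & 0xFF;
            // Swap BGR to RGB and subtract the matching channel mean.
            out[p] = (float) (r - meanRgb[0]);
            out[pixels + p] = (float) (g - meanRgb[1]);
            out[2 * pixels + p] = (float) (b - meanRgb[2]);
        }
        return out;
    }
}
```

With OpenCV itself, Dnn.blobFromImage takes a mean Scalar and a swapRB flag that cover the same two steps, so the hand-written loop is mainly useful for seeing exactly what reaches the network.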

More Answers (1)

Luigi Treccozzi on 11 Jul 2019
Thanks very much for answering. I tried to execute these lines of code before training the network:
for i = 1:227
    for j = 1:227
        mea(i, j, 1) = 126.4784;
        mea(i, j, 2) = 109.7040;
        mea(i, j, 3) = 108.7779;
    end
end
layers(1) = imageInputLayer([227 227 3], 'Name', 'data', 'AverageImage', mea);
But the same problem occurs. I also tried using no normalisation at all, since the images are not modified in Java (apart from a resize to fit the network input, because the incoming images may have different sizes).
To find those means I used this code:
sumRedChannel = 0;
sumGreenChannel = 0;
sumBlueChannel = 0;
numPix = 0;
for k = 1 : num_files
    baseFileName = files(k).name;
    fullFileName = fullfile('free/', baseFileName);
    %fprintf(1, 'Now reading %s\n', fullFileName);
    if mod(k,1000) == 0
        fprintf(1, '+1000-Now reading %s\n', fullFileName);
    end
    % such as reading it in as an image array with imread()
    rgbImage = imread(fullFileName);
    [rows, columns, numberOfColorChannels] = size(rgbImage);
    if numberOfColorChannels ~= 3
        fprintf(1, 'Now reading %s\n', fullFileName);
        disp(numberOfColorChannels);
    end
    redChannel = rgbImage(:,:,1);
    greenChannel = rgbImage(:,:,2);
    blueChannel = rgbImage(:,:,3);
    sumRedChannel = sumRedChannel + sum(redChannel(:));
    sumGreenChannel = sumGreenChannel + sum(greenChannel(:));
    sumBlueChannel = sumBlueChannel + sum(blueChannel(:));
    numPix = numPix + rows*columns;
end
meanRed = sumRedChannel/numPix;
meanGreen = sumGreenChannel/numPix;
meanBlue = sumBlueChannel/numPix;
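One way to check that Java sees the same pixel statistics as the MATLAB loop above is to compute the per-channel means of a single image on both sides and compare them. The sketch below is a hypothetical diagnostic, not part of the original code: it assumes the image is available as an interleaved 8-bit BGR buffer, and the class and method names are made up for illustration. If the red and blue means come out swapped relative to MATLAB, the channel order differs between the two sides.

```java
/** Hypothetical diagnostic: per-channel means of an interleaved 8-bit BGR buffer. */
final class ChannelMeans {

    /** Returns the means in R, G, B order so they can be compared with the MATLAB values. */
    static double[] meanRgb(byte[] bgr) {
        if (bgr.length == 0 || bgr.length % 3 != 0) {
            throw new IllegalArgumentException("expected a non-empty buffer of 3 bytes per pixel");
        }
        long sumB = 0;
        long sumG = 0;
        long sumR = 0;
        for (int i = 0; i < bgr.length; i += 3) {
            // Mask each signed Java byte to get the unsigned 0..255 pixel value.
            sumB += bgr[i] & 0xFF;
            sumG += bgr[i + 1] & 0xFF;
            sumR += bgr[i + 2] & 0xFF;
        }
        double numPix = bgr.length / 3;
        return new double[] { sumR / numPix, sumG / numPix, sumB / numPix };
    }
}
```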
I have no idea what the problem could be.
