How to get the class names from a dlnetwork?

Hi All,
I am moving my code from R2023b to R2024b, and I want to use dlnetwork objects and the trainnet function in general to take advantage of their better performance. Most of the migration has been fine, except for the move away from classify to minibatchpredict and scores2label. This is my current code, with a SeriesNetwork implementation:
yPred = classify(Model, xAll);
However, classify does not work with dlnetwork objects, so according to the documentation I have to use something like this instead:
yPred = predict(Model, cell2mat(xAll)');
Predictions = scores2label(yPred,classNames);
The issue is that I have to save classNames explicitly, which I don't have to do with the SeriesNetwork implementation, where I am aware the class names are stored as a private property.
Is there a way to do something similar, without having to save the class names myself? Otherwise, it seems that SeriesNetwork might be the better choice.
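For comparison, the pattern used in the R2024a+ documentation examples derives the class names from the categorical training labels with categories() and keeps them next to the network, for example in the same MAT-file. A minimal sketch on a tiny synthetic data set; XTrain, TTrain, XTest, and the file name model.mat are hypothetical:

```matlab
% Tiny synthetic feature data set (hypothetical, just so the sketch runs)
XTrain = rand(100, 4);
TTrain = categorical(randi(3, 100, 1), 1:3, ["a" "b" "c"]);

% Class names in the order of the categorical labels used for training
classNames = categories(TTrain);

layers = [featureInputLayer(4); fullyConnectedLayer(numel(classNames)); softmaxLayer];
options = trainingOptions("adam", MaxEpochs=2, Verbose=false);
net = trainnet(XTrain, TTrain, layers, "crossentropy", options);

% Save the class names together with the network so they travel as a pair
save("model.mat", "net", "classNames");

% Later: reload both and convert the scores to labels
s = load("model.mat");
XTest = rand(10, 4);
scores = minibatchpredict(s.net, XTest);
yPred = scores2label(scores, s.classNames);
```

This works, but it is exactly the bookkeeping the question is trying to avoid.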
Thanks in advance!

Accepted Answer

Matt J on 25 Nov 2024 at 18:06
Edited: Matt J on 25 Nov 2024 at 22:05
After you've trained the model as a dlnetwork, there is nothing stopping you from converting it back to a traditional SeriesNetwork and adding a classification layer for the purposes of prediction.
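A minimal sketch of that conversion, assuming a series architecture that ends in a softmaxLayer and synthetic data (all variable names hypothetical). classificationLayer, assembleNetwork, and classify are marked as not recommended from R2024a but still ship with R2024b. The class names are still needed once, when building the output layer; after that they live inside the network.

```matlab
% Tiny synthetic problem (hypothetical data)
XTrain = rand(100, 4);
TTrain = categorical(randi(3, 100, 1), 1:3, ["a" "b" "c"]);
classNames = categories(TTrain);

% Train as a dlnetwork with trainnet
layers = [featureInputLayer(4); fullyConnectedLayer(numel(classNames)); softmaxLayer];
options = trainingOptions("adam", MaxEpochs=2, Verbose=false);
net = trainnet(XTrain, TTrain, layers, "crossentropy", options);

% Append a legacy classification output layer that carries the class names
legacyLayers = [net.Layers; classificationLayer("Classes", classNames)];
seriesNet = assembleNetwork(legacyLayers);

% classify works again, and the names can be read back from the output layer
XTest = rand(10, 4);
yPred = classify(seriesNet, XTest);
storedNames = seriesNet.Layers(end).Classes;
```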
You could also define your own custom layer to replace the traditional output layer. The idea is to replace it with a layer that stores your class names and does essentially the same thing as the legacy classification output layer but, because it is derived from the nnet.layer.Layer parent class, is compatible with dlnetwork(). With this approach, you would write the forward() method of the custom layer to ignore the class names for the purposes of training, while the predict() method would be written to use them.
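One possible variant of this idea, sketched under assumptions: the hypothetical layer classNamesLayer below only stores the names and passes the softmax scores through unchanged, and the labels are computed outside the network with scores2label. This sidesteps returning categorical labels from predict(), since dlnetwork layers are, as far as I know, expected to output numeric data. The class definition must be saved in its own file, classNamesLayer.m, on the MATLAB path (it is also needed when the network is loaded later).

```matlab
% File: classNamesLayer.m (hypothetical pass-through layer that carries class names)
classdef classNamesLayer < nnet.layer.Layer
    properties
        ClassNames   % class names in training order (not learnable)
    end
    methods
        function layer = classNamesLayer(classNames, name)
            layer.Name = name;
            layer.Description = "Pass-through layer storing class names";
            layer.ClassNames = classNames;
        end
        function Z = predict(~, X)
            % Identity: scores pass through unchanged; forward() defaults to this
            Z = X;
        end
    end
end
```

Training and inference could then look like this, again with synthetic data:

```matlab
XTrain = rand(100, 4);
TTrain = categorical(randi(3, 100, 1), 1:3, ["a" "b" "c"]);
classNames = categories(TTrain);

layers = [
    featureInputLayer(4)
    fullyConnectedLayer(numel(classNames))
    softmaxLayer
    classNamesLayer(classNames, "classnames")];
options = trainingOptions("adam", MaxEpochs=2, Verbose=false);
net = trainnet(XTrain, TTrain, layers, "crossentropy", options);

% Only net needs saving; the names are read back from its last layer
XTest = rand(10, 4);
scores = minibatchpredict(net, XTest);
yPred = scores2label(scores, net.Layers(end).ClassNames);
```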

Release

R2024b