MATLAB Answers

How to train semantic segmentation network to recognize one class?

Kevin Petersen on 17 Mar 2021
Answered: Srivardhan Gadila on 26 Mar 2021
I am training a semantic segmentation network (I'm trying SegNet) to classify skin in an image. This is the only class that needs to be labeled, and I labeled all my training and testing data using the Image Labeler app in MATLAB, so each label image stores the background as zeros and skin as ones.
However, I didn't specify a 'background' class because it seemed unnecessary; I just left those pixels unlabeled. So when I try to build a network with one class (calling lgraph = segnetLayers(imageSize,numClasses,model) with numClasses set to one), I get an error saying the number of classes must be greater than one. Is it simply a property of these networks that they need two or more classes to classify? If I set it to two, training fails, as expected, and I'm also not sure whether labeling everything else as 'background' will give good results (I suspect the network may just classify everything as background). Thoughts?

Answers (1)

Srivardhan Gadila on 26 Mar 2021
For any classification task there should, in general, be at least two classes, or a single class with a Yes/No (binary) output. In a single-class segmentation task, each pixel is classified as Yes or No: your class name for Yes and background for No.
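A short derivation of why the two formulations are interchangeable (the notation is introduced here, not taken from the answer): for one pixel with a skin score z_1, a background score z_0, and a target t (1 = skin, 0 = background), a two-class softmax reduces to a sigmoid of the score difference, and the two-class cross-entropy reduces to binary cross-entropy.

```latex
p(\text{skin}) = \frac{e^{z_1}}{e^{z_1} + e^{z_0}}
              = \frac{1}{1 + e^{-(z_1 - z_0)}}
              = \sigma(z), \qquad z = z_1 - z_0

\mathcal{L} = -\bigl[\, t \log p(\text{skin}) + (1 - t) \log\bigl(1 - p(\text{skin})\bigr) \bigr]
            = -\bigl[\, t \log \sigma(z) + (1 - t) \log\bigl(1 - \sigma(z)\bigr) \bigr]
```

So a network with one output channel followed by a sigmoid can represent exactly the same per-pixel probabilities as a two-class softmax; segnetLayers simply builds the two-class form, which is why it rejects numClasses = 1.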
If you want to train a network without labeling the background explicitly, create the lgraph with the segnetLayers function and then remove the softmax and pixel classification (output) layers:
lgraph = removeLayers(lgraph,{'softmax','pixelLabels'});
Then replace the final convolution2dLayer ('decoder1_conv1') with a new convolution2dLayer that has a single filter, i.e. one output channel:
replaceDecoderConv1Layer = convolution2dLayer(3,1,'Name','decoder1_conv1','Stride',[1 1],'Padding',[1 1 1 1]);
lgraph = replaceLayer(lgraph,'decoder1_conv1',replaceDecoderConv1Layer);
dlnet = dlnetwork(lgraph)  % a dlnetwork is trained with a custom training loop, not with trainNetwork
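Because the resulting dlnetwork has no output layer, the loss has to be supplied in a custom training loop. The sketch below applies a sigmoid to the single output channel and computes binary cross-entropy by hand; the function names modelGradients and binaryCrossEntropy are hypothetical, and it assumes the last layer of the graph outputs unbounded scores. If a ReLU still follows decoder1_conv1 in your graph (inspect it with analyzeNetwork(lgraph)), the scores can never be negative and the sigmoid cannot drop below 0.5, so that layer would need to be removed as well.

```matlab
function [loss,gradients] = modelGradients(dlnet,dlX,T)
% Hypothetical loss/gradient helper, evaluated with dlfeval.
% dlX : input images as a dlarray with format 'SSCB'
% T   : skin mask as a dlarray with format 'SSCB' (1 = skin, 0 = background)
    dlZ = forward(dlnet,dlX);          % one channel of raw per-pixel scores
    dlY = sigmoid(dlZ);                % per-pixel probability of skin
    loss = binaryCrossEntropy(dlY,T);
    gradients = dlgradient(loss,dlnet.Learnables);
end

function loss = binaryCrossEntropy(dlY,T)
% Mean binary cross-entropy over all pixels; eps guards against log(0).
    loss = -mean(T.*log(dlY + eps) + (1 - T).*log(1 - dlY + eps),'all');
end
```

Inside the loop, one iteration would look roughly like [loss,gradients] = dlfeval(@modelGradients,dlnet,dlX,T); followed by [dlnet,avgGrad,avgSqGrad] = adamupdate(dlnet,gradients,avgGrad,avgSqGrad,iteration);, with T built as dlarray(single(mask),'SSCB').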
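If instead you only want to address the class imbalance problem (for example, the network predicting background everywhere), keep the explicit background class and either use a focalLossLayer in place of the pixel classification layer of the lgraph or use class weighting.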
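A minimal sketch of the class-weighting route for the data described in the question, assuming the label images store 1 for skin and 0 for unlabeled pixels and live in the hypothetical folders data/images and data/labels. Mapping pixel value 0 to an explicit "background" class in pixelLabelDatastore gives SegNet the two classes it requires without relabeling anything, and median frequency balancing up-weights the rarer class. The image size, the 'vgg16' encoder (which needs the VGG-16 support package), the training options, and the assumption that the training images already match imageSize are all illustrative choices.

```matlab
% Hypothetical folder names; point these at your own image and label folders.
imds = imageDatastore(fullfile('data','images'));

% Map label value 1 to "skin" and the unlabeled value 0 to "background".
classNames = ["skin","background"];
labelIDs   = [1 0];
pxds = pixelLabelDatastore(fullfile('data','labels'),classNames,labelIDs);

% Median frequency balancing: the rarer class gets the larger weight.
tbl = countEachLabel(pxds);
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ imageFreq;

% Two-class SegNet whose output layer uses the class weights.
imageSize = [360 480 3];                              % assumed input size
lgraph = segnetLayers(imageSize,numel(classNames),'vgg16');
pxLayer = pixelClassificationLayer('Name','labels', ...
    'Classes',tbl.Name,'ClassWeights',classWeights);
lgraph = replaceLayer(lgraph,'pixelLabels',pxLayer);

% Pair images with labels and train with illustrative options.
ds = pixelLabelImageDatastore(imds,pxds);
options = trainingOptions('sgdm','InitialLearnRate',1e-3, ...
    'MaxEpochs',30,'MiniBatchSize',4);
net = trainNetwork(ds,lgraph,options);
```

Weighting keeps the standard trainNetwork workflow, so it is usually the smaller change compared with the custom-loop approach.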
