Deep Learning Semantic Segmentation Example

Ryan Rizzo on 1 Dec 2018
Answered: Sourabh on 11 Jun 2025
To familiarize myself with semantic segmentation and convolutional neural networks, I am working through this tutorial by MathWorks:
I did not use the pretrained version of SegNet since I wanted to test on my custom data set. All the code is the same, but I have different classes and **fewer labels**. The image below shows each label name and the number of pixels associated with it.
To compensate for the low pixel count of class 2, median frequency balancing was performed:
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;   % per-class pixel frequency
classWeights = median(imageFreq) ./ imageFreq;        % median frequency balancing
I then proceed to train the network using the code provided in the example, with `options` and `lgraph` unchanged. The SegNet network is created with weights initialized from the VGG-16 network.
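For reference, that construction step looks roughly like this (the image size and class list are placeholders for my data set):

% Build SegNet with encoder weights initialized from VGG-16.
% imageSize and classes are placeholders for my custom data set.
imageSize  = [360 480 3];
numClasses = numel(classes);
lgraph     = segnetLayers(imageSize, numClasses, 'vgg16');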
Unlike the example, I get a much lower global accuracy:
To gain further insight, I plotted the mini-batch accuracy and mini-batch loss against each iteration.
It is clear that the accuracy fluctuates wildly and ends up worse than it started, so the network learned essentially nothing. However, the loss decreased gradually.
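For anyone reproducing the plot, a minimal sketch using the second output of trainNetwork (variable names such as pximds follow the example and are assumptions):

% Capture the per-iteration training trace and plot it.
[net, info] = trainNetwork(pximds, lgraph, options);
figure
yyaxis left,  plot(info.TrainingAccuracy), ylabel('Mini-batch accuracy (%)')
yyaxis right, plot(info.TrainingLoss),     ylabel('Mini-batch loss')
xlabel('Iteration')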
A possible solution I propose would be to use inverse frequency balancing (sketched below). However, median frequency balancing was already performed above, so I doubt how much this would help.
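For completeness, inverse frequency balancing with the same imageFreq would be:

% Inverse frequency balancing: weight each class by 1/frequency,
% then normalize so the weights average to 1 (normalization optional).
classWeights = 1 ./ imageFreq;
classWeights = classWeights / mean(classWeights);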
Is the terrible performance related to simply not having enough training data? Can anything be done to improve performance with the existing data?
Any suggestions are greatly appreciated.

Answers (1)

Sourabh on 11 Jun 2025
The graph shows that your network is optimizing (the loss is decreasing) but not generalizing or learning meaningful class boundaries. You are already trying to address this with median frequency balancing, which is good, but in very low-data scenarios it can overcompensate, making the model oscillate or diverge.
Instead, you can try smoothing weights:
epsilon = 1e-6;  % small constant to stabilize weights for very rare classes
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ (imageFreq + epsilon);
or try log-scaled weights:
totalPixels = sum(tbl.PixelCount);
classWeights = log(1 + totalPixels ./ tbl.PixelCount);  % rarer classes get larger, but bounded, weights
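Whichever weighting you choose, plug it back into the pixel classification layer of your lgraph. A minimal sketch; the layer name 'pixelLabels' is an assumption, so check lgraph.Layers for the actual name in your network:

% Replace the existing pixel classification layer with one carrying
% the new class weights. 'pixelLabels' is an assumed layer name.
pxLayer = pixelClassificationLayer('Name','pixelLabels', ...
    'Classes',tbl.Name,'ClassWeights',classWeights);
lgraph = replaceLayer(lgraph,'pixelLabels',pxLayer);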
Segmentation tasks are data-hungry. A small dataset can mean poor generalization, incomplete coverage of class variations, and noisy or unstable learning. To mitigate this, you can perform data augmentation using MATLAB's “imageDataAugmenter”. Create an augmenter like this (a sketch of wiring it into training follows the code):
imageAugmenter = imageDataAugmenter( ...
    'RandRotation', [-20,20], ...
    'RandXTranslation', [-10 10], ...
    'RandYTranslation', [-10 10], ...
    'RandXReflection', true);
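To use it for semantic segmentation, pass it to the training datastore; a sketch, where imdsTrain and pxdsTrain are assumed to be your image and pixel-label datastores:

% Apply the same random transforms to images and their pixel labels
% by attaching the augmenter to the training datastore.
pximdsAug = pixelLabelImageDatastore(imdsTrain, pxdsTrain, ...
    'DataAugmentation', imageAugmenter);
net = trainNetwork(pximdsAug, lgraph, options);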
Sometimes, unstable training is caused by the learning rate being too high or the mini-batch size being too small. You can try reducing “InitialLearnRate” to 1e-4 or lower and using a larger mini-batch size if memory allows.
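A sketch of such training options (the specific values are starting points to tune, not prescriptions):

% Gentler training configuration: lower learning rate, larger mini-batch.
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 1e-4, ...
    'MiniBatchSize', 8, ...
    'MaxEpochs', 100, ...
    'Momentum', 0.9, ...
    'Shuffle', 'every-epoch', ...
    'VerboseFrequency', 10);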
For more information and examples on “imageDataAugmenter”, please refer to the following MATLAB documentation:
