How can I do multi-class classification with the 3D U-Net?

The 3D U-Net segmentation example performs binary classification.
I was trying to extend the example to multi-class classification, but my loss function kept returning a constant value.
Has anyone been able to perform multi-class classification with the 3D U-Net in MATLAB?

Answers (1)

Shashank Gupta on 27 Aug 2019
A multiclass segmentation network is very similar to a binary one. You mainly need to change the last layers of your model so that the network outputs one channel per class (typically, the final convolution has as many filters as there are classes and is followed by a softmax layer and a pixel classification layer). MATLAB provides the function "pixelLabelDatastore", which reads pixel label images as categorical labels that you can use as the label targets of a semantic segmentation network.
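To make the "one channel per class" idea concrete, here is a small language-agnostic sketch in plain Python (not MATLAB code, and not how MATLAB's layers are implemented internally). For every voxel, the final layer produces one score per class, softmax turns those scores into class probabilities, and training compares them with a one-hot target. The function names and the three-class example values are hypothetical.

```python
import math

def softmax(scores):
    """Turn the raw class scores of one voxel into probabilities that sum to 1."""
    m = max(scores)  # subtract the max score for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def one_hot(label, num_classes):
    """Encode an integer class label (0 .. num_classes-1) as a one-hot target vector."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def voxel_cross_entropy(scores, label):
    """Cross-entropy for one voxel: minus the log-probability of its true class."""
    return -math.log(softmax(scores)[label])

# Illustrative values: 3 classes (e.g. background + 2 structures), one voxel.
scores = [2.0, 0.5, -1.0]            # output of the final convolution for this voxel
probs = softmax(scores)
predicted = probs.index(max(probs))  # class with the highest probability
print(probs, predicted, one_hot(0, 3), voxel_cross_entropy(scores, 0))
```

The only structural difference from the binary case is the length of these per-voxel vectors, which is why changing the last layers is usually enough.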
Also, a constant loss can have many causes, and class imbalance is a common one. Try using a weighted multiclass Dice loss function instead of “crossentropy”.
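As a rough illustration of a weighted multiclass Dice loss, the sketch below implements the generalized Dice loss (Sudre et al., 2017). Each class is weighted by the inverse square of its number of target voxels, so small structures are not swamped by the background. It is plain Python for illustration only. The function name, the data layout (one probability vector per voxel), and the clamping used for classes absent from a batch are my own assumptions, not MATLAB's API.

```python
def generalized_dice_loss(probs, targets, eps=1e-8):
    """Weighted (generalized) multiclass Dice loss.

    probs:   per-voxel softmax vectors, shape [num_voxels][num_classes]
    targets: per-voxel one-hot vectors, same shape
    Returns a value in [0, 1]; 0 means perfect overlap for every class.
    """
    num_classes = len(probs[0])
    numerator = 0.0
    denominator = 0.0
    for c in range(num_classes):
        g = [t[c] for t in targets]  # ground truth for class c
        p = [q[c] for q in probs]    # predicted probability for class c
        # Inverse squared class volume; clamped at 1 so a class absent
        # from this batch does not cause a division by zero (assumption).
        w = 1.0 / max(sum(g), 1.0) ** 2
        numerator += w * sum(gi * pi for gi, pi in zip(g, p))
        denominator += w * sum(gi + pi for gi, pi in zip(g, p))
    return 1.0 - 2.0 * numerator / (denominator + eps)
```

A perfect prediction gives a loss near 0 and a completely wrong one gives 1. Because of the class weights, a network that predicts only the dominant class gets a loss close to 1 when the other classes are much smaller, so it cannot settle into that trivial solution as easily as with unweighted cross-entropy.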
If that does not help, try using an adaptive learning rate for your network (for example, an adaptive solver such as Adam). Also check the target images before feeding them to your network: sometimes the target and predicted images come out transposed relative to each other because of how MATLAB handles the data.
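One quick way to check for the transposed target/prediction problem is to compare a predicted mask with the target and with the transposed target: if the transposed version overlaps much better, the two have their axes swapped. The sketch below works on a single 2-D slice of a binary mask stored as nested Python lists. The helper names are hypothetical, and for 3-D volumes you would repeat the check for each pair of spatial axes.

```python
def dice(a, b):
    """Dice overlap of two binary 2-D masks of the same shape (lists of rows)."""
    inter = sum(x * y for ra, rb in zip(a, b) for x, y in zip(ra, rb))
    total = sum(map(sum, a)) + sum(map(sum, b))
    return 2.0 * inter / total if total else 1.0

def transpose(mask):
    """Swap rows and columns of a 2-D mask."""
    return [list(col) for col in zip(*mask)]

def looks_transposed(pred, target):
    """Heuristic: True if pred matches the transposed target better than target itself."""
    if (len(pred), len(pred[0])) != (len(target), len(target[0])):
        # Shapes differ: swapped axes are likely if they agree after transposing.
        return (len(pred), len(pred[0])) == (len(target[0]), len(target))
    return dice(pred, transpose(target)) > dice(pred, target)

target = [[1, 1, 0],
          [0, 0, 0],
          [0, 0, 0]]
pred = transpose(target)                # simulate a prediction with swapped axes
print(looks_transposed(pred, target))   # True
print(looks_transposed(target, target)) # False
```

Symmetric masks give the same score either way, so run the check on slices that contain an asymmetric structure.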
The 3D tumor segmentation example may also help you set up your model.
Shashank Gupta on 30 Aug 2019
Hi Atallah,
I'm sorry, but it's difficult to pinpoint a specific reason for a constant loss; there could be many. You can investigate your model further to see what is causing the problem. Try visualizing the “softmax” output instead of looking directly at the predicted classes, and see whether you can find a pattern. It's also possible that the model is underfitting (this rarely happens). Also check the “bias” term in each layer and see whether it is driving the layer output to zero (for example, a large negative bias feeding a ReLU). The optimizer can also get stuck at a saddle point and fail to escape; a different optimizer may help (although I assume you have already tried this).
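Following the suggestion to look at the softmax output rather than only the predicted classes, the sketch below summarizes per-voxel softmax vectors: the mean probability of each class and the fraction of voxels where each class wins. A network stuck at a constant loss can show one class winning almost everywhere. The 0.99 threshold and the function names are hypothetical choices, and the code is plain Python rather than MATLAB.

```python
from collections import Counter

def summarize_softmax(probs):
    """Summarize per-voxel softmax vectors, shape [num_voxels][num_classes].

    Returns (mean probability per class, fraction of voxels where each class wins).
    """
    n = len(probs)
    num_classes = len(probs[0])
    mean_prob = [sum(p[c] for p in probs) / n for c in range(num_classes)]
    winners = Counter(max(range(num_classes), key=lambda c: p[c]) for p in probs)
    win_fraction = [winners[c] / n for c in range(num_classes)]
    return mean_prob, win_fraction

def looks_collapsed(probs, threshold=0.99):
    """Hypothetical heuristic: flag outputs where one class wins nearly every voxel."""
    _, win_fraction = summarize_softmax(probs)
    return max(win_fraction) >= threshold

# Illustrative output of a network that predicts "background" everywhere.
stuck = [[0.9, 0.05, 0.05]] * 100
print(summarize_softmax(stuck), looks_collapsed(stuck))
```

Tracking these two summaries over several iterations also shows whether the probabilities are moving at all, which helps tell a collapsed solution apart from an optimizer that is not updating the weights.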
I can't think of any other reasons right now.
I hope this gives you some leads.
Atallah Baydoun on 6 Nov 2019
Hey Shashank,
Another technical question came up, and I was wondering if you could help me understand how the data for each mini-batch is chosen.
Let us assume that we have 20 images and extract only one patch per image, for a total of 20 patches.
Let us also suppose that we set the mini-batch size to 5.
At each iteration, trainNetwork chooses 5 of the 20 patches to form its mini-batch.
How is this selection done? Is it completely random? I tried to debug the trainNetwork code but couldn't find anything.
Thanks,
Atallah
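As a generic illustration of the setup in the question (20 patches and a mini-batch size of 5, so 20 / 5 = 4 iterations per epoch), the sketch below shows shuffle-then-split batching, in which every patch is used exactly once per epoch. This is not trainNetwork's code. In MATLAB, the 'Shuffle' option of trainingOptions ('once', 'every-epoch', or 'never') controls whether and when the data order is shuffled, and a patch-extraction datastore may add its own randomness. The function name and the dropping of an incomplete final batch are assumptions of this sketch.

```python
import random

def epoch_minibatches(num_patches, batch_size, shuffle=True, seed=None):
    """Split patch indices 0..num_patches-1 into mini-batches for one epoch.

    With shuffle=True the order is randomly permuted before splitting, so each
    patch lands in exactly one mini-batch per epoch (sampling without replacement).
    An incomplete final batch is dropped (an assumption of this sketch).
    """
    order = list(range(num_patches))
    if shuffle:
        random.Random(seed).shuffle(order)
    num_batches = num_patches // batch_size
    return [order[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]

batches = epoch_minibatches(20, 5, seed=0)
print(len(batches))  # 4 mini-batches (iterations) per epoch
print(batches)
```

With shuffling applied only once, the same four groups of patches would repeat every epoch; reshuffling every epoch changes which patches share a mini-batch.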

