High loss values with a custom training loop
Good day everyone,
I'm currently working on a custom training loop for cardiac segmentation. However, I'm encountering extremely high loss values during training when using the crossentropy function.
function [loss,Y] = modelLoss_test(net,X,T)
Y = forward(net,X);          % network predictions
loss = crossentropy(Y,T);    % built-in cross-entropy loss
end
I've verified that X (size 256 x 208 x 1 x 2) and T (size 256 x 208 x 2 x 2) are both 4-D dlarray objects, and that Y and T each have a maximum value of 1 and a minimum of 0. However, calling loss = crossentropy(Y,T) directly gives an extremely high loss value (e.g. 4.123 x 10^5), whereas computing the loss manually with the following code gives a much more reasonable value (e.g. 15.356):
yy = gather(extractdata(Y(:)));
tt = gather(extractdata(T(:)));
loss = crossentropy(yy,tt);
For context, I'm using a U-Net with the Adam optimizer. I replaced the final convolution layer of the U-Net with a layer that has 2 output channels:
lgraph = replaceLayer(lgraph, 'Final-ConvolutionLayer', convolution2dLayer(3, 2, 'padding','same','Name', 'Segmentation-Layer'));
I also tried incorporating class weights into the loss function, which reduced the loss only insignificantly:
weights = [0.95, 0.05];
loss = crossentropy(Y,T,weights,WeightsFormat="BC");
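As a sanity check on what the weighted loss is summing, here is a minimal plain-MATLAB sketch (my own illustration, not the built-in's exact implementation): it assumes the class channels run along dimension 3, as in the 256 x 208 x 2 x 2 arrays above, and uses a tiny 2 x 2 spatial example in place of 256 x 208.

```matlab
% Hypothetical 2x2-pixel example with 2 class channels along dim 3
Yp = cat(3, [0.7 0.6; 0.8 0.9], [0.3 0.4; 0.2 0.1]);  % per-pixel class probabilities
Tp = cat(3, [1 1; 1 0],        [0 0; 0 1]);           % one-hot targets per pixel
w  = reshape([0.95, 0.05], 1, 1, []);                 % align class weights with channel dim
wloss = -sum(w .* Tp .* log(Yp), 'all');              % weighted cross-entropy sum, ~1.15
```

Each class channel's contribution is scaled by its weight before summing; note this sketch does no normalization, so its scale need not match the built-in function's default.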
Could someone explain why there is such a large difference in loss values when using MATLAB's built-in crossentropy function versus my manual calculation? I would greatly appreciate any advice or solutions to this problem. Thank you in advance!
4 Comments
Answers (1)
Jayanti
on 16 Oct 2024
Hi Hui,
Let's start by analysing the difference between MATLAB's built-in crossentropy and your custom loss calculation.
Generally, in image segmentation and classification tasks, the true labels are provided in one-hot encoded format. The built-in function interprets "T" (the true labels) as a one-hot encoded array when calculating the loss.
In your custom calculation, however, you extract the data from the "dlarray" and store it in the "tt" variable. Once the data has been extracted from the "dlarray", crossentropy no longer treats it as a one-hot encoded array, so the two calls compute different values.
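To see how interpretation alone changes the number, here is a minimal plain-MATLAB illustration (no toolbox calls; the values are made up): the same numbers give different "cross-entropy" results depending on whether the targets are treated as one-hot labels per observation or the result is averaged over all elements.

```matlab
y = [0.7 0.2; 0.3 0.8];                 % predictions: 2 classes (rows) x 2 observations
t = [1 0; 0 1];                         % one-hot targets
lossOneHot  = -sum(t(:) .* log(y(:)));  % one -log(y) term per observation, ~0.58
lossPerElem = -mean(t(:) .* log(y(:))); % per-element average, ~0.14 (4x smaller here)
```

With a 256 x 208 x 2 x 2 array the element count is over 200,000, so a summed loss can easily be several orders of magnitude larger than an averaged one.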
If you want to calculate the loss on the extracted values, you can use the code below, which gives the same loss value as the built-in crossentropy function.
yy = gather(extractdata(Y(:)));   % predictions as a plain numeric vector
tt = gather(extractdata(T(:)));   % one-hot targets as a plain numeric vector
loss_array = -sum(tt .* log(yy)); % summed cross-entropy over all elements
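One practical caveat with the formula above: log(0) is -Inf, so any prediction that is exactly zero where the target is 1 makes the loss Inf. A common guard (my addition, not part of the original answer) is a small epsilon, and an optional division by the batch size gives a per-observation average; the stand-in vectors below play the role of yy and tt.

```matlab
yy = [0.7; 0.2; 0.3; 0.8];                  % stand-in predictions
tt = [1; 0; 0; 1];                          % stand-in one-hot targets
epsVal    = 1e-12;                          % guards against log(0)
lossSum   = -sum(tt .* log(yy + epsVal));   % unnormalized sum, as above (~0.58)
batchSize = 2;                              % e.g. the trailing 2 in 256x208x2x2
lossMean  = lossSum / batchSize;            % per-observation average (~0.29)
```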
I ran the code with this custom loss calculation, and it gives the same result as the built-in cross-entropy function.