Training a multiple-output segmentation network based on U-Net - Staying in MATLAB?

Hello!
I have created a neural network using the Deep Learning Toolbox in MATLAB that I want to train. The network has a backbone created by the unetLayers function, which I have modified to have two outputs instead of one: a semantic segmentation output and an image-to-image regression output. An image of the final layers of the network can be seen below.
I want to train this network in MATLAB. The "Segmentation-Layer" output layer has a cross-entropy loss function (LOSS1) and outputs a 128x128x2 image. The "regressionLayer" output layer has a mean-squared-error loss function (LOSS2) and outputs a 128x128x1 image. I want to combine these losses by adding them together, LOSS = LOSS1 + LOSS2, and have backpropagation use this combined loss.
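To make the goal concrete, the combined-loss computation I have in mind would look roughly like this sketch (tSeg and tReg are placeholder names for my segmentation and regression targets, and model is the kind of model function discussed below):

function [loss,gradients,state] = modelGradients(dlX,tSeg,tReg,parameters,state)
    % Forward pass producing both outputs
    [dlYSeg,dlYReg,state] = model(dlX,parameters,true,state);

    % LOSS1: cross-entropy on the 128x128x2 softmax segmentation output
    loss1 = crossentropy(dlYSeg,tSeg);

    % LOSS2: mean squared error on the 128x128x1 regression output
    loss2 = mse(dlYReg,tReg);

    % LOSS = LOSS1 + LOSS2, used for backpropagation
    loss = loss1 + loss2;
    gradients = dlgradient(loss,parameters);
end

This function would then be evaluated with dlfeval inside a custom training loop.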
However, the only example I can find for training a multiple-output network with a custom loss function is the Train Network with Multiple Outputs example in the Deep Learning Toolbox documentation.
Using this tutorial as a reference, I would have to write a model function for my network. However, it is not straightforward to modify this function for my own network structure, which is significantly larger than the one used in the example. The model function from the tutorial looks like this:
function [dlY1,dlY2,state] = model(dlX,parameters,doTraining,state)

    % Convolution
    weights = parameters.conv1.Weights;
    bias = parameters.conv1.Bias;
    dlY = dlconv(dlX,weights,bias,'Padding',2);

    % Batch normalization, ReLU
    offset = parameters.batchnorm1.Offset;
    scale = parameters.batchnorm1.Scale;
    trainedMean = state.batchnorm1.TrainedMean;
    trainedVariance = state.batchnorm1.TrainedVariance;

    if doTraining
        [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

        % Update state
        state.batchnorm1.TrainedMean = trainedMean;
        state.batchnorm1.TrainedVariance = trainedVariance;
    else
        dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
    end

    dlY = relu(dlY);

    % Convolution, batch normalization (skip connection)
    weights = parameters.convSkip.Weights;
    bias = parameters.convSkip.Bias;
    dlYSkip = dlconv(dlY,weights,bias,'Stride',2);

    offset = parameters.batchnormSkip.Offset;
    scale = parameters.batchnormSkip.Scale;
    trainedMean = state.batchnormSkip.TrainedMean;
    trainedVariance = state.batchnormSkip.TrainedVariance;

    if doTraining
        [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);

        % Update state
        state.batchnormSkip.TrainedMean = trainedMean;
        state.batchnormSkip.TrainedVariance = trainedVariance;
    else
        dlYSkip = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);
    end

    % Convolution
    weights = parameters.conv2.Weights;
    bias = parameters.conv2.Bias;
    dlY = dlconv(dlY,weights,bias,'Padding',1,'Stride',2);

    % Batch normalization, ReLU
    offset = parameters.batchnorm2.Offset;
    scale = parameters.batchnorm2.Scale;
    trainedMean = state.batchnorm2.TrainedMean;
    trainedVariance = state.batchnorm2.TrainedVariance;

    if doTraining
        [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

        % Update state
        state.batchnorm2.TrainedMean = trainedMean;
        state.batchnorm2.TrainedVariance = trainedVariance;
    else
        dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
    end

    dlY = relu(dlY);

    % Convolution
    weights = parameters.conv3.Weights;
    bias = parameters.conv3.Bias;
    dlY = dlconv(dlY,weights,bias,'Padding',1);

    % Batch normalization
    offset = parameters.batchnorm3.Offset;
    scale = parameters.batchnorm3.Scale;
    trainedMean = state.batchnorm3.TrainedMean;
    trainedVariance = state.batchnorm3.TrainedVariance;

    if doTraining
        [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

        % Update state
        state.batchnorm3.TrainedMean = trainedMean;
        state.batchnorm3.TrainedVariance = trainedVariance;
    else
        dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
    end

    % Addition, ReLU
    dlY = dlYSkip + dlY;
    dlY = relu(dlY);

    % Fully connect (angles)
    weights = parameters.fc1.Weights;
    bias = parameters.fc1.Bias;
    dlY2 = fullyconnect(dlY,weights,bias);

    % Fully connect, softmax (labels)
    weights = parameters.fc2.Weights;
    bias = parameters.fc2.Bias;
    dlY1 = fullyconnect(dlY,weights,bias);
    dlY1 = softmax(dlY1);
end
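For my network, I expect the two heads at the end of such a model function to be convolutional rather than fully connected, roughly like this sketch (the parameter names segConv and regConv are hypothetical):

% Segmentation head: convolution + softmax over channels (128x128x2 output)
weights = parameters.segConv.Weights;
bias = parameters.segConv.Bias;
dlYSeg = dlconv(dlY,weights,bias,'Padding','same');
dlYSeg = softmax(dlYSeg);

% Regression head: convolution only (128x128x1 output)
weights = parameters.regConv.Weights;
bias = parameters.regConv.Bias;
dlYReg = dlconv(dlY,weights,bias,'Padding','same');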
QUESTION 1: Is there a built-in function to create such a model function from a network, or an easier way to generate it than writing it by hand?
Alternatively, is there another way in MATLAB to train a network with a modified loss function like the one I describe?
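One idea I have considered is converting the layer graph to a dlnetwork and letting forward return both outputs, roughly like this (a sketch; dlnetwork objects cannot contain output layers, and the layer names here are hypothetical):

% Remove the two loss layers before conversion
lgraph = removeLayers(lgraph,{'Segmentation-Layer','regressionLayer'});
net = dlnetwork(lgraph);

% Forward pass returning both outputs, selected by layer name
[dlYSeg,dlYReg] = forward(net,dlX,'Outputs',{'softmax','finalConv'});

I am not sure from which release dlnetwork and forward support multiple outputs this way, or whether this works with a unetLayers-based graph.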
QUESTION 2: Since this is the first time I am developing a multiple-output network, I want to know whether what I want to do is feasible in the MATLAB environment. My main alternative is to export the created network in the ONNX format to Keras/TensorFlow, train the network on one of those Python platforms, and then import it back into MATLAB, again via ONNX. Would that be a better approach?
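The round trip I have in mind would look roughly like this (a sketch; it requires the Deep Learning Toolbox Converter for ONNX Model Format support package, and the file names are made up):

% Export the network to ONNX (I am not sure whether an untrained
% layer graph is accepted, or whether it must be assembled first)
exportONNXNetwork(lgraph,'unet2out.onnx');

% ... train in Keras/TensorFlow and save the trained model as ONNX ...

% Import the trained weights back as a layer graph
lgraphTrained = importONNXLayers('unet2out_trained.onnx','ImportWeights',true);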
Thank you in advance.
