How to get R-square, mean absolute error and mean square error after training a neural network?

Hi all, I train a neural network with the following commands:
net.divideFcn = 'dividerand';
net.divideParam.trainRatio= 0.6;
net.divideParam.testRatio= 0.2;
net.divideParam.valRatio= 0.2;
[net,tr]=train(net,input,target);
I want to get the R-square, mean absolute error and mean square error for the training, test and validation data.
Could you please advise?

Answers (1)

Paras Gupta on 18 Jul 2024
Hi Ninlawat,
I understand that you want to compute different network performance metrics on the train, test, and validation data after training a neural network object in MATLAB.
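For reference, for one subset (training, validation or test) with n samples, targets t_i, network outputs y_i and target mean \bar{t}, the three metrics computed below are the standard definitions:

```latex
\[
\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}(t_i - y_i)^2, \qquad
\mathrm{MAE} = \frac{1}{n}\sum_{i=1}^{n}\lvert t_i - y_i\rvert, \qquad
R^2 = 1 - \frac{\sum_{i=1}^{n}(t_i - y_i)^2}{\sum_{i=1}^{n}(t_i - \bar{t})^2}
\]
```

The numerator and denominator of R^2 correspond to SS_res and SS_tot in the code.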
The following code illustrates one way to do this:
% dummy data
input = rand(1, 100); % 1 feature, 100 samples
target = 2 * input + 1 + 0.1 * randn(1, 100); % Linear relation with some noise
% Define the feedforward network
net = feedforwardnet(10); % 10 hidden neurons
% Set up the data division
net.divideFcn = 'dividerand';
net.divideParam.trainRatio = 0.6;
net.divideParam.valRatio = 0.2;
net.divideParam.testRatio = 0.2;
% Train the network
[net, tr] = train(net, input, target);
% Get the network outputs
outputs = net(input);
% Separate the outputs for training, validation, and testing
trainOutputs = outputs(tr.trainInd);
valOutputs = outputs(tr.valInd);
testOutputs = outputs(tr.testInd);
% Separate the targets for training, validation, and testing
trainTargets = target(tr.trainInd);
valTargets = target(tr.valInd);
testTargets = target(tr.testInd);
% Calculate and display R-square, MAE, and MSE for each dataset
datasets = {'train', 'val', 'test'};
outputsList = {trainOutputs, valOutputs, testOutputs};
targetsList = {trainTargets, valTargets, testTargets};
for i = 1:numel(datasets)
    name = datasets{i};
    y = outputsList{i};   % network outputs for this subset
    t = targetsList{i};   % targets for this subset
    % R-square (coefficient of determination)
    SS_res = sum((t - y).^2);
    SS_tot = sum((t - mean(t)).^2);
    R_square = 1 - SS_res / SS_tot;
    % Mean Absolute Error (MAE)
    MAE = mae(t - y);
    % Mean Square Error (MSE)
    MSE = mse(net, t, y);
    % Display the results
    fprintf('%s R-square: %.4f\n', name, R_square);
    fprintf('%s MAE: %.4f\n', name, MAE);
    fprintf('%s MSE: %.4f\n', name, MSE);
    fprintf('\n');
end
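If you prefer not to depend on the toolbox functions mae and mse, the same metrics can be written as plain anonymous functions. This is a minimal sketch; the names r2Fn, maeFn and mseFn are hypothetical, chosen only for this example. The plain mean of squared errors should agree with mse(net, t, y) as long as net.performParam keeps its defaults (no normalization or regularization), which is an assumption worth checking for your network.

```matlab
% Toolbox-free versions of the three metrics (hypothetical helper names).
% t = targets, y = network outputs; vectors of the same size.
r2Fn  = @(t, y) 1 - sum((t - y).^2) / sum((t - mean(t)).^2);  % R-square
maeFn = @(t, y) mean(abs(t - y));                             % mean absolute error
mseFn = @(t, y) mean((t - y).^2);                             % mean square error

% Small hand-checkable example: errors t - y are [-0.5 0 0.5 0]
t = [1 2 3 4];
y = [1.5 2 2.5 4];

r2Val  = r2Fn(t, y);    % 1 - 0.5/5 = 0.9
maeVal = maeFn(t, y);   % 1/4 = 0.25
mseVal = mseFn(t, y);   % 0.5/4 = 0.125

fprintf('R-square: %.4f\nMAE: %.4f\nMSE: %.4f\n', r2Val, maeVal, mseVal);
```

Inside the loop above you could then call, for example, r2Fn(targetsList{i}, outputsList{i}) instead of computing SS_res and SS_tot by hand.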
For more information on the properties and functions used in the code above, refer to their documentation pages.
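When comparing these numbers with the regression plots available from the training window (plotregression), keep in mind that those plots report R, the correlation coefficient between outputs and targets, not R-square. Squaring R gives the same value as 1 - SS_res/SS_tot only in special cases (for example, a least-squares linear fit with an intercept); in general the two differ. The sketch below uses base MATLAB's corrcoef on a made-up example where the outputs are the targets shifted by +1:

```matlab
% Correlation-based R versus the coefficient of determination (illustrative data).
t = [1 2 3];
y = [2 3 4];                       % outputs offset by +1 from the targets

C = corrcoef(t, y);                % 2x2 correlation matrix
R = C(1, 2);                       % Pearson correlation: 1 here

SS_res = sum((t - y).^2);          % = 3
SS_tot = sum((t - mean(t)).^2);    % = 2
R_square = 1 - SS_res / SS_tot;    % = -0.5

fprintf('R = %.4f, R^2 from R = %.4f, 1 - SS_res/SS_tot = %.4f\n', ...
        R, R^2, R_square);
```

Here R is perfect even though every prediction is off by 1, while 1 - SS_res/SS_tot is negative, so make sure you report the quantity you actually intend.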
Hope this helps.
