MATLAB Answers

Trying to make an RNN for 2D signal data classification with a 2D classified matrix output. Error using trainNetwork (line 165): Invalid training data. Responses must be a vector of categorical responses, or a cell array of categorical response sequences.

Orestis Marantos on 24 Mar 2021
Commented: Orestis Marantos on 22 Apr 2021
Hello Everyone,
I'm trying to implement a neural network classification algorithm for signal data of shape (800,500), where 800 is the number of time steps and 500 is the number of observations. I want to train a network to decide, for every time step, whether it belongs to class 0, 1, or -1. So my responses should have the same shape as the XTrain data, (800,500).
After trying to use XTrain and YTrain as plain two-dimensional arrays, I understood that this is not possible, so I changed them to cell arrays as the documentation requires.
My XTrain data are a cell array of double sequences, and so are the YTrain data. Out of the 500 observations I used 70% (350 observations) as training data.
As the MATLAB trainNetwork documentation suggests:
But I'm still getting the same error:
The layers and options for my RNN so far are the following; please make suggestions :)
Please help

Answers (1)

Srivardhan Gadila on 29 Mar 2021
According to the documentation of the trainNetwork input arguments sequences and responses, for the syntax net = trainNetwork(sequences,responses,layers,options):
- The input data should be an N-by-1 cell array of numeric arrays, where N is the number of observations. For vector sequences, each observation must be a c-by-s matrix, where c is the number of features and s is the sequence length.
- The responses should be an N-by-1 cell array of categorical label sequences, where N is the number of observations. Each observation must be a 1-by-s sequence of categorical labels, where s is the sequence length of the corresponding predictor sequence.
The following code may help you:
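%% Create network.
inputSize = 800;        % sequence length (time steps per observation)
numFeatures = 1;        % one signal value per time step
numClasses = 3;         % classes -1, 0 and 1
numHiddenUnits = 100;
layers = [ ...
    sequenceInputLayer(numFeatures,'Name','Sequence Input')
    lstmLayer(numHiddenUnits,'OutputMode','sequence','Name','LSTM Layer')
    % Map the hidden state to one score per class, then to probabilities;
    % classificationLayer must be preceded by a softmax layer.
    fullyConnectedLayer(numClasses,'Name','Fully Connected Layer')
    softmaxLayer('Name','Softmax Layer')
    classificationLayer('Name','Classification Layer')];
lgraph = layerGraph(layers);
%% Create random training data.
numTrainSamples = 50;
% N-by-1 cell array; each cell is a 1-by-inputSize numeric sequence.
trainData = arrayfun(@(x)rand([1 inputSize]),1:numTrainSamples,'UniformOutput',false)';
% N-by-1 cell array; each cell is a 1-by-inputSize categorical sequence.
trainLabels = arrayfun(@(x)categorical(randi([-1 1], 1,inputSize)),1:numTrainSamples,'UniformOutput',false)';
%% Train the network.
options = trainingOptions('adam', ...
    'InitialLearnRate',0.005, ...
    'Verbose',1);
net = trainNetwork(trainData,trainLabels,lgraph,options);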
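For the 800-by-500 matrices described in the question (time steps along rows, observations along columns), the conversion to this cell-array format could look like the sketch below. The names X, Y, XTrain and YTrain and the random placeholder data are assumptions for illustration, not the asker's actual data or code.

```matlab
% Placeholder data standing in for the real signals (assumption):
% rows are time steps, columns are observations.
numTimeSteps = 800;
numObs = 500;
X = rand(numTimeSteps, numObs);            % signal values
Y = randi([-1 1], numTimeSteps, numObs);   % labels in {-1, 0, 1}

% Use the first 70% of the observations for training.
numTrain = round(0.7 * numObs);

XTrain = cell(numTrain, 1);
YTrain = cell(numTrain, 1);
for k = 1:numTrain
    % Predictor: 1-by-800 numeric row (1 feature, 800 time steps).
    XTrain{k} = X(:, k)';
    % Response: 1-by-800 categorical row with a fixed set of categories,
    % so every sequence shares the same classes even if one is absent.
    YTrain{k} = categorical(Y(:, k)', [-1 0 1]);
end
```

Passing the value set [-1 0 1] to categorical keeps the category list identical across observations, which avoids mismatched classes when a particular sequence happens to contain only some of the labels.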
  1 Comment
Orestis Marantos on 22 Apr 2021
Hello Srivardhan, thank you so much for your thorough explanations. Could you also recommend a way to overcome the high skewness of my problem? I created more samples with simulations.
Specifically, I have 2000 samples, each with 800 time steps. In each sample there are some points of interest (signal goes up (on), signal goes down (off)). I would like to detect those points, so I thought that MIMO multi-label classification with an LSTM could solve my problem. So I want to classify each time step of each sample as 0 if it is not a point of interest, or 1 if it is (I also changed the -1 to 1). Ultimately I'm interested only in the ones, so I'm not sure whether this is the best method to follow.
Because more than 98% of the time steps are zeros, the LSTM keeps giving me only zeros after training. I have implemented a custom weighted binary cross-entropy function with a weight of 0.97 for the ones and 0.03 for the zeros, but the algorithm still returns only zeros for each prediction.
Do you have any suggestions on how to solve this highly skewed problem? I read about over-sampling and under-sampling, but these won't help in my problem, since it is natural to have so few active points (ones).
Thanks a lot in advance
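One common way to choose class weights like those mentioned above is inverse class frequency, computed from the training labels rather than set by hand. The sketch below uses made-up labels with roughly 2% ones; the variable names are illustrative, and whether this alone fixes the all-zeros predictions depends on the data and the network.

```matlab
% Made-up labels for illustration: 20 sequences of 800 steps, about 2% ones.
rng(0);
numSeq = 20;
numTimeSteps = 800;
labels = cell(numSeq, 1);
for k = 1:numSeq
    y = double(rand(1, numTimeSteps) < 0.02);
    labels{k} = categorical(y, [0 1]);
end

% Pool all time steps and count how often each class occurs.
allLabels = [labels{:}];           % 1-by-(numSeq*numTimeSteps) categorical
counts = countcats(allLabels);     % [count of class 0, count of class 1]

% Inverse-frequency weights, normalised to sum to 1, so the rare
% class receives the larger weight.
classWeights = (1 ./ counts) / sum(1 ./ counts);
```

In releases where classificationLayer accepts the 'Classes' and 'ClassWeights' options, a vector like classWeights can be passed there together with categories(allLabels) instead of maintaining a custom loss function; check the documentation of the installed release before relying on this.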
