Why Are Hidden State and Cell State Vectors Zero After Training an LSTM Model with trainNetwork Functionality?
Shubham Baisthakur
on 10 Oct 2023
Commented: Shubham Baisthakur
on 20 Oct 2023
I am training an LSTM model using trainNetwork, and the following is the architecture of my model:
layers = [ ...
    sequenceInputLayer(size(X_train{1},1))
    layerNormalizationLayer
    lstmLayer(x.num_hidden_units,'OutputMode','sequence')
    fullyConnectedLayer(x.num_layers_ffnn)
    dropoutLayer(0.1)
    fullyConnectedLayer(1)
    regressionLayer];
And I am training this using the following command:
options = trainingOptions('adam', ...
'MaxEpochs', 75, ...
'MiniBatchSize', x.batch_size, ...
'SequenceLength', 'longest', ...
'Shuffle', 'once', ...
'L2Regularization',0.01,...
'ValidationData',{X_val,Y_val}, ...
'ValidationFrequency',10,...
'Verbose',false,...
'ExecutionEnvironment','multi-gpu');
% Train the LSTM network
net = trainNetwork(X_train, Y_train, layers, options);
After training the model, the HiddenState and CellState values of the LSTM layer are vectors of zeros. Why is this happening? I expected these vectors to have non-zero values, showing that the long-term dependency between the input and output parameters has been captured.
Accepted Answer
Neha
on 20 Oct 2023
Hi Shubham,
An LSTM (Long Short-Term Memory) layer is designed to remember values over arbitrary time intervals, which is what lets it learn long-term dependencies. What training learns, however, is stored in the layer's weights (InputWeights, RecurrentWeights and Bias), not in the hidden and cell states. The HiddenState and CellState properties of the trained network hold the state the layer starts from, and they are reset to zero after training. This is standard behavior for LSTMs; it does not mean that the layer has learned nothing or that it is not working properly.
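To see this on a small case, here is a minimal sketch that trains a tiny stand-in network and inspects its LSTM layer. The data, sizes and variable names (numFeatures, numHidden, X, Y) are made up for illustration and are not from your model; only the property names come from the LSTM layer itself.

```matlab
% Hypothetical toy data: 16 sequences, 3 features, 20 time steps each.
rng(0);
numFeatures = 3;
numHidden = 8;
X = arrayfun(@(~) randn(numFeatures, 20), 1:16, 'UniformOutput', false)';
Y = cellfun(@(x) sum(x, 1), X, 'UniformOutput', false);  % 1-by-20 targets

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHidden, 'OutputMode', 'sequence')
    fullyConnectedLayer(1)
    regressionLayer];
options = trainingOptions('adam', 'MaxEpochs', 5, 'Verbose', false);
net = trainNetwork(X, Y, layers, options);

lstm = net.Layers(2);  % the LSTM layer of this toy network

% Starting states stored on the layer after training.
disp(any(lstm.HiddenState ~= 0))
disp(any(lstm.CellState ~= 0))

% Learned parameters: this is where training leaves its mark.
disp(norm(lstm.InputWeights, 'fro'))
disp(norm(lstm.RecurrentWeights, 'fro'))
```

In your network the LSTM layer is the third entry of the layer array, so the equivalent check would start from net.Layers(3).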
If you want to carry the LSTM state forward (for example, in time series prediction where the model should remember the state from the previous sequence), refer to the explanation of Open Loop Forecasting and Closed Loop Forecasting in the Deep Learning Toolbox documentation on time series forecasting.
That example uses the predictAndUpdateState function, which updates the network state at every time step.
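As a rough sketch of that pattern, the loop below feeds one time step at a time and keeps the updated network, so the state carries over between calls. The toy network and the names XTrain, YTrain, XNew and YPred are placeholders for illustration; with your own model you would skip the training part and use your trained net and data instead.

```matlab
% Hypothetical stand-in network, trained briefly on one random sequence.
rng(0);
XTrain = {randn(2, 30)};
YTrain = {sum(XTrain{1}, 1)};
layers = [ ...
    sequenceInputLayer(2)
    lstmLayer(4, 'OutputMode', 'sequence')
    fullyConnectedLayer(1)
    regressionLayer];
options = trainingOptions('adam', 'MaxEpochs', 3, 'Verbose', false);
net = trainNetwork(XTrain, YTrain, layers, options);

% New data to predict on, one column per time step (made up here).
XNew = randn(2, 10);

% Start from the initial state, then step through the sequence.
net = resetState(net);
YPred = zeros(1, size(XNew, 2));
for t = 1:size(XNew, 2)
    % The returned network carries the updated hidden and cell state
    % into the next iteration.
    [net, YPred(:, t)] = predictAndUpdateState(net, XNew(:, t));
end
```

Calling resetState again returns the network to its initial state, which is useful before starting on an unrelated sequence.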
Hope this helps!