Deep Learning Toolbox: Input format mismatch

I want to build an LSTM network with an attention mechanism.
According to a paper, I need to add the output of the LSTM layer to the processed hidden state.
But when I use additionLayer, I get a message that the input formats don't match.
I found that the LSTM output is in 10(C) x 1(B) x 1(T) format, while the hidden state is in 10(C) x 1(B) format.
Normally they should be addable, because they have the same size in the first two dimensions and the LSTM output has length 1 in the third dimension.
What do I need to do to add them? Thanks!
Below is the deep network I designed according to the paper. It is a CNN-LSTM-Attention network used to predict time series.
sequence = sequenceInputLayer([kim,1],"Name",'sequence');
conv1 = convolution1dLayer(3,64,'Name','conv1','Padding','causal');
batchnorm1 = batchNormalizationLayer('Name','batchnorm1');
relu1 = reluLayer('Name','relu1');
conv2 = convolution1dLayer(3,64,'Name','conv2','Padding','causal');
batchnorm2 = batchNormalizationLayer('Name','batchnorm2');
relu2 = reluLayer('Name','relu2');
conv3 = convolution1dLayer(3,64,'Name','conv3','Padding','causal');
batchnorm3 = batchNormalizationLayer('Name','batchnorm3');
add1 = additionLayer(2,"Name",'add1');
relu3 = reluLayer('Name','relu3');
maxpool = maxPooling1dLayer(10,'Name','maxpool');
flatten1 = flattenLayer("Name",'flatten1');
lstm = lstmLayer(10,'Name','lstm','HasStateOutputs',1);
fc1 = fullyConnectedLayer(10,'Name','fc1');
tanh1 = tanhLayer('Name','tanh1');
softmax1 = softmaxLayer('Name','softmax1');
multiplication1 = multiplicationLayer(2,'Name','multiplication1');
fc2 = fullyConnectedLayer(10,'Name','fc2');
tanh2 = tanhLayer('Name','tanh2');
softmax2 = softmaxLayer('Name','softmax2');
multiplication2 = multiplicationLayer(2,'Name','multiplication2');
add2 = additionLayer(3,"Name",'add2');
fc3 = fullyConnectedLayer(10,'Name','fc3');
softmax3 = softmaxLayer('Name','softmax3');
regression = regressionLayer('Name','regression');
lgraph = layerGraph;
lgraph = addLayers(lgraph,sequence);
lgraph = addLayers(lgraph,conv1);
lgraph = addLayers(lgraph,batchnorm1);
lgraph = addLayers(lgraph,relu1);
lgraph = addLayers(lgraph,conv2);
lgraph = addLayers(lgraph,batchnorm2);
lgraph = addLayers(lgraph,relu2);
lgraph = addLayers(lgraph,conv3);
lgraph = addLayers(lgraph,batchnorm3);
lgraph = addLayers(lgraph,add1);
lgraph = addLayers(lgraph,relu3);
lgraph = addLayers(lgraph,maxpool);
lgraph = addLayers(lgraph,flatten1);
lgraph = addLayers(lgraph,lstm);
lgraph = addLayers(lgraph,fc1);
lgraph = addLayers(lgraph,tanh1);
lgraph = addLayers(lgraph,softmax1);
lgraph = addLayers(lgraph,multiplication1);
lgraph = addLayers(lgraph,fc2);
lgraph = addLayers(lgraph,tanh2);
lgraph = addLayers(lgraph,softmax2);
lgraph = addLayers(lgraph,multiplication2);
lgraph = addLayers(lgraph,add2);
lgraph = addLayers(lgraph,fc3);
lgraph = addLayers(lgraph,softmax3);
lgraph = addLayers(lgraph,regression);
lgraph = connectLayers(lgraph,'sequence','conv1');
lgraph = connectLayers(lgraph,'conv1','batchnorm1');
lgraph = connectLayers(lgraph,'batchnorm1','relu1');
lgraph = connectLayers(lgraph,'batchnorm1','add1/in1');
lgraph = connectLayers(lgraph,'relu1','conv2');
lgraph = connectLayers(lgraph,'conv2','batchnorm2');
lgraph = connectLayers(lgraph,'batchnorm2','relu2');
lgraph = connectLayers(lgraph,'relu2','conv3');
lgraph = connectLayers(lgraph,'conv3','batchnorm3');
lgraph = connectLayers(lgraph,'batchnorm3','add1/in2');
lgraph = connectLayers(lgraph,'add1/out','relu3');
lgraph = connectLayers(lgraph,'relu3','maxpool');
lgraph = connectLayers(lgraph,'maxpool','flatten1');
lgraph = connectLayers(lgraph,'flatten1','lstm');
lgraph = connectLayers(lgraph,'lstm/out','add2/in1');
lgraph = connectLayers(lgraph,'lstm/hidden','fc1');
lgraph = connectLayers(lgraph,'fc1','tanh1');
lgraph = connectLayers(lgraph,'tanh1','softmax1');
lgraph = connectLayers(lgraph,'softmax1','multiplication1/in1');
lgraph = connectLayers(lgraph,'lstm/hidden','multiplication1/in2');
lgraph = connectLayers(lgraph,'multiplication1','add2/in2');
lgraph = connectLayers(lgraph,'lstm/cell','fc2');
lgraph = connectLayers(lgraph,'fc2','tanh2');
lgraph = connectLayers(lgraph,'tanh2','softmax2');
lgraph = connectLayers(lgraph,'softmax2','multiplication2/in1');
lgraph = connectLayers(lgraph,'lstm/cell','multiplication2/in2');
lgraph = connectLayers(lgraph,'multiplication2','add2/in3');
lgraph = connectLayers(lgraph,'add2','fc3');
lgraph = connectLayers(lgraph,'fc3','softmax3');
lgraph = connectLayers(lgraph,'softmax3','regression');
plot(lgraph)
analyzeNetwork(lgraph)

Accepted Answer

David Ho on 9 Oct 2023
Hello 湃林,
It looks like the error in the multiplication layer comes from the fact that one of the inputs, from the LSTM layer, has a time dimension, while the other two inputs do not. That means that if, for example, your sequences have length 100, the layer is trying to multiply a [10 1] matrix, a [10 1] matrix, and a [10 1 100] array, which is not a supported operation (multiplication layers do not permit implicit expansion).
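The format difference described above can be reproduced outside the network with formatted dlarray objects. This is only a minimal sketch: the sizes (10 channels, 1 observation, 100 time steps) and the variable names are assumptions for illustration, not values taken from the question's data.

```matlab
% Hypothetical sizes: 10 channels, 1 observation, 100 time steps.
out    = dlarray(rand(10,1,100),"CBT");  % stands in for the LSTM sequence output
hidden = dlarray(rand(10,1),"CB");       % stands in for the LSTM hidden state output

disp(dims(out))      % format labels of the sequence output
disp(dims(hidden))   % format labels of the state output

% Plain arithmetic on formatted dlarray objects matches dimensions by label,
% so the "CB" array is expanded along the "T" dimension here.
s = out + hidden;
disp(size(s))
```

The built-in combination layers check that their input formats agree, which is why the same operation is rejected inside the layer graph.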
One method of resolving the issue is to change the output mode of the LSTM layer from "sequence" to "last", so that it only outputs a single time step:
lstm = lstmLayer(10,'Name','lstm','HasStateOutputs',1,'OutputMode','last');
However, I'm not sure if this is the result you need. If the time steps are important and you want to perform multiplication with implicit expansion, you can replace the multiplication layer (and the following addition layer) with a functionLayer:
multiplication1 = functionLayer(@(x,y,z) x .* y .* z, 'Formattable', true);
add2 = functionLayer(@(x,y) x + y, 'Formattable', true);
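In the graph as posted in the question, add2 is the layer that receives three inputs (the sequence output plus the two attention branches), while each multiplication layer receives two state inputs of the same format. Under that reading, a sketch of a drop-in replacement might look like the following; the choice to replace only add2 and the reuse of the name 'add2' are assumptions based on the posted code, not something the answer states.

```matlab
% Sketch only: replaces the three-input additionLayer from the question.
% 'Formattable',true passes formatted dlarray inputs to the function,
% so "CBT" and "CB" inputs are combined by dimension label.
add2 = functionLayer(@(x,y,z) x + y + z, ...
    'Formattable',true, ...
    'Name','add2');

% functionLayer takes its number of inputs from the function handle and
% names them in1, in2, in3 by default, so the existing connectLayers calls
% to 'add2/in1', 'add2/in2' and 'add2/in3' should not need to change.
disp(add2.NumInputs)
disp(add2.InputNames)
```

Giving the layer a name matters here because the connectLayers calls in the question refer to layers by name.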
Hopefully one of these options will give you the result you're looking for.
As a final note, if you want to implement the self-attention mechanism from the "Attention Is All You Need" paper, you may wish to look at Deep Learning Toolbox's implementation of the self-attention layer, selfAttentionLayer.
This handles the matrix multiplication of the keys, queries and values internally, so you don't need to implement it yourself.
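For orientation, a self-attention layer can be placed in a layer array like any other layer. The sketch below uses hypothetical sizes and layer names that are not taken from the question, and it is not a reconstruction of the paper's architecture.

```matlab
% Sketch only: all sizes below are assumed for illustration.
numChannels    = 10;   % assumed number of features per time step
numHeads       = 2;    % assumed number of attention heads
numKeyChannels = 16;   % assumed; must be divisible by numHeads

layers = [
    sequenceInputLayer(numChannels,'Name','sequence')
    selfAttentionLayer(numHeads,numKeyChannels,'Name','selfattention')
    fullyConnectedLayer(1,'Name','fc')
    regressionLayer('Name','regression')];

% analyzeNetwork(layers)   % uncomment to inspect activation sizes and formats
```

The layer requires R2023a or later, which matches the release listed for this question.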
Best regards,
David
  1 Comment
湃 林 on 9 Oct 2023
Thanks so much! I think I will solve this problem soon with your help.

Release

R2023a
