- Make sure the sizes of WQ, WK, WV, and WO match the input dimensions. WQ, WK, and WV should be outputSize-by-numChannels, so that their second dimension equals the channel dimension of X and Y. WO should be outputSize-by-outputSize, and outputSize must be divisible by numHeads.
- Make sure the outputs of the fc_stft and fc_cwt layers are compatible with the input dimensions your crossAttention function expects. Both branches end in fully connected layers with an output size of 100, so check the outputSize variable again and confirm that it matches the expected size.
- Connect the outputs of fc_stft and fc_cwt to the inputs of crossAttention (X and Y) instead of directly to the attention layer.
- Use MATLAB's pagemtimes function for page-wise matrix multiplication of multidimensional arrays such as the queries, keys, and values. The * operator is defined for 2-D matrices, not page by page over 3-D arrays like X. See the MathWorks documentation: https://www.mathworks.com/help/matlab/ref/pagemtimes.html
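Applying these points to the example in the question, a corrected cross-attention could look like the sketch below: pagemtimes projects every page of X and Y, and attention combines queries from X with keys and values from Y. The sizes are the question's; giving Y 60 time steps instead of 100 is an assumption added only to show that the output length follows the queries (X). The helper is written as an anonymous function so the snippet runs as a single script, and the random weights are placeholders, not trained parameters.

```matlab
% Sizes from the question. numTimeStepsY = 60 is an illustrative assumption.
numChannels = 10;
numObservations = 128;
numTimeStepsX = 100;
numTimeStepsY = 60;

X = dlarray(rand(numChannels, numObservations, numTimeStepsX));   % C x B x T (query sequence)
Y = dlarray(rand(numChannels, numObservations, numTimeStepsY));   % C x B x T (key/value sequence)

numHeads = 8;
outputSize = numChannels*numHeads;   % 80 channels = 8 heads x 10 channels per head

% Placeholder projection weights: outputSize-by-numChannels, so pagemtimes(W, X)
% maps the channel dimension of every page X(:,:,t) from 10 to 80.
WQ = rand(outputSize, numChannels);
WK = rand(outputSize, numChannels);
WV = rand(outputSize, numChannels);
WO = rand(outputSize, outputSize);

% Cross-attention: queries from X, keys and values from Y, then the output projection.
crossAttention = @(X, Y, numHeads, WQ, WK, WV, WO) ...
    pagemtimes(WO, attention( ...
        pagemtimes(WQ, X), ...   % queries from X
        pagemtimes(WK, Y), ...   % keys from Y
        pagemtimes(WV, Y), ...   % values from Y
        numHeads, 'DataFormat', 'CBT'));

Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO);
disp(size(Z))   % expected: 80 128 100 (outputSize x numObservations x numTimeStepsX)
```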
How to make cross-attention use attentionLayer?
I want to replace the dual-branch merge section of the model in the following link with cross-attention for fusion, but it does not work. Is my approach incorrect? I have written an example, but I still don't understand how to embed it into the model in the link.
Network one (fails: the loss does not decrease):
initialLayers = [
sequenceInputLayer(1, "MinLength", numSamples, "Name", "input", "Normalization", "zscore", "SplitComplexInputs", true)
convolution1dLayer(7, 2, "Stride", 1, "Name", "conv1d") % explicit name used by connectLayers below
];
stftBranchLayers = [
stftLayer("TransformMode", "squaremag", "Window", hann(64), "OverlapLength", 52, "Name", "stft", "FFTLength", 256, "WeightLearnRateFactor", 0 )
functionLayer(@(x)dlarray(x, 'SCBS'), Formattable=true, Acceleratable=true, Name="stft_reformat")
convolution2dLayer([4, 8], 16, "Padding", "same", "Name", "stft_conv_1")
layerNormalizationLayer("Name", "stft_layernorm_1")
reluLayer("Name", "stft_relu_1")
maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_1")
convolution2dLayer([4, 8], 24, "Padding", "same", "Name", "stft_conv_2")
layerNormalizationLayer("Name", "stft_layernorm_2")
reluLayer("Name", "stft_relu_2")
maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_2")
convolution2dLayer([4, 8], 32, "Padding", "same", "Name", "stft_conv_3")
layerNormalizationLayer("Name", "stft_layernorm_3")
reluLayer("Name", "stft_relu_3")
maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_3")
flattenLayer("Name", "stft_flatten")
dropoutLayer(0.5)
fullyConnectedLayer(100,"Name","fc_stft")
];
cwtBranchLayers = [
cwtLayer("SignalLength", numSamples, "TransformMode", "squaremag", "Name","cwt", "WeightLearnRateFactor", 0);
functionLayer(@(x)dlarray(x, 'SCBS'), Formattable=true, Acceleratable=true, Name="cwt_reformat")
convolution2dLayer([4, 8], 16, "Padding", "same", "Name", "cwt_conv_1")
layerNormalizationLayer("Name", "cwt_layernorm_1")
reluLayer("Name", "cwt_relu_1")
maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_1")
convolution2dLayer([4, 8], 24, "Padding", "same", "Name", "cwt_conv_2")
layerNormalizationLayer("Name", "cwt_layernorm_2")
reluLayer("Name", "cwt_relu_2")
maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_2")
convolution2dLayer([4, 8], 32, "Padding", "same", "Name", "cwt_conv_3")
layerNormalizationLayer("Name", "cwt_layernorm_3")
reluLayer("Name", "cwt_relu_3")
maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_3")
flattenLayer("Name", "cwt_flatten")
dropoutLayer(0.5)
fullyConnectedLayer(100,"Name","fc_cwt")
];
finalLayers = [
attentionLayer(4,"Name","attention")
layerNormalizationLayer("Name","layernorm")
fullyConnectedLayer(48,"Name","fc_1")
fullyConnectedLayer(numel(waveformClasses),"Name","fc_2")
softmaxLayer("Name","softmax")
];
dlLayers2 = dlnetwork(initialLayers);
dlLayers2 = addLayers(dlLayers2, stftBranchLayers);
dlLayers2 = addLayers(dlLayers2, cwtBranchLayers);
dlLayers2 = addLayers(dlLayers2, finalLayers);
dlLayers2 = connectLayers(dlLayers2, "conv1d", "stft");
dlLayers2 = connectLayers(dlLayers2, "conv1d", "cwt");
% Key and value come from the STFT branch, query from the CWT branch
dlLayers2 = connectLayers(dlLayers2,"fc_stft","attention/key");
dlLayers2 = connectLayers(dlLayers2,"fc_stft","attention/value");
dlLayers2 = connectLayers(dlLayers2,"fc_cwt","attention/query");
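A possible reason network one does not fuse the branches: fc_stft and fc_cwt follow flattenLayer and output 100 features per observation with no time or spatial dimension left. Assuming attentionLayer then treats each observation as a single token, the softmax over one key is exactly 1, so the output equals the values (fc_stft) and the query (fc_cwt) has no effect. The sketch below checks this with the attention function (not the layer itself); all sizes are hypothetical.

```matlab
% Hypothetical sizes: 100 features per branch (like fc_stft / fc_cwt), 4 observations.
numFeatures = 100;
numObservations = 4;
numHeads = 4;                         % 100 channels split into 4 heads of 25

% Each observation is ONE token: C x B x T = 100 x 4 x 1 (trailing T is singleton).
q1 = dlarray(rand(numFeatures, numObservations), "CBT");
q2 = dlarray(rand(numFeatures, numObservations), "CBT");   % a different query
k  = dlarray(rand(numFeatures, numObservations), "CBT");
v  = dlarray(rand(numFeatures, numObservations), "CBT");

A1 = attention(q1, k, v, numHeads);
A2 = attention(q2, k, v, numHeads);

% With a single key per observation, every attention weight is 1,
% so the output is just the values, whatever the query is.
disp(max(abs(extractdata(A1 - v)), [], "all"))    % expected: 0 (or round-off)
disp(max(abs(extractdata(A1 - A2)), [], "all"))   % expected: 0 (or round-off)
```

If this is what happens inside the network, one option may be to keep a sequence dimension in the attention inputs so that each query can attend over several keys.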
My example (is it right?):
numChannels = 10;
numObservations = 128;
numTimeSteps = 100;
X = rand(numChannels,numObservations,numTimeSteps);
X = dlarray(X);
Y = rand(numChannels,numObservations,numTimeSteps);
Y = dlarray(Y);
numHeads = 8;
outputSize = numChannels*numHeads;
WQ = rand(outputSize, numChannels, 1, 1);
WK = rand(outputSize, numChannels, 1, 1);
WV = rand(outputSize, numChannels, 1, 1);
WO = rand(outputSize, outputSize, 1, 1);
Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO);
function Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO)
queries = WQ * X;
keys = WK * Y;
values = WV * Y;
A = attention(queries, keys, values, numHeads, 'DataFormat', 'CBT');
Z = WO * A;
end
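For reference, the computation crossAttention is meant to perform for a single observation is standard multi-head scaled dot-product attention. Here X is numChannels-by-T_X, Y is numChannels-by-T_Y, Q_h, K_h, and V_h are the h-th blocks of d_k rows of Q, K, and V, and the softmax normalizes each column over the T_Y keys, so Z is outputSize-by-T_X. This assumes the default scale 1/sqrt(d_k) and that each head takes a contiguous block of channels.

```latex
\begin{aligned}
Q &= W_Q X, \quad K = W_K Y, \quad V = W_V Y, \qquad d_k = \frac{\text{outputSize}}{\text{numHeads}},\\
A_h &= V_h \,\operatorname{softmax}\!\left(\frac{K_h^{\top} Q_h}{\sqrt{d_k}}\right), \qquad h = 1,\dots,\text{numHeads},\\
Z &= W_O \begin{bmatrix} A_1 \\ \vdots \\ A_{\text{numHeads}} \end{bmatrix}.
\end{aligned}
```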
Accepted Answer
Sahas on 18 Dec 2024 (edited on 18 Dec 2024)
As I understand it, you would like to replace the dual-branch merge section of the model with cross-attention. I went through your implementation and noticed a few things. It looks structurally correct, but check the points listed at the top of this page when combining cross-attention with the classification workflow from the documentation example.
Hope this is beneficial!