What is the fastest workaround for doing a 3D channel-wise separable convolution using dlconv()?

Hello.
I'm constructing a custom convolutional layer in MATLAB R2021a. Currently, I have an input dlarray of size 256x256x64x3x20 (X), with the format 'SSCSB', where S=Spatial, C=Channel and B=Batch. I also have a set of (Gabor) filters, generated using an external function, of size 3x3x64x3 (K) with the format 'SSCS'. I want to perform a 3D convolution on each 256x256x1x3x1 image (of which there are 20*64 in each batch), using the corresponding 3x3x1x3 filter. In other words, I want to perform 64 separate 3D convolutions on 256x256x3 images, using a different 3x3x3 filter for each image, resulting in a 256x256x1 image in each case. This should be done once for each sample of the batch. The result should be a 256x256x64x20 array.
As far as I understand the dlconv function, this convolution should be written as follows, performing a channel-wise separable convolution for each of the 64 "channels":
X: 256x256x64x3x20 = SxSxCxSxB
K: 3x3x1x1x64x3 = SxSxCxU1xU2xS, where C is the "Number of channels per group", U1 is the "Number of filters per group" and U2 is the "Number of groups" (note that I've added the singleton dimensions to match the weights format described at https://se.mathworks.com/help/deeplearning/ref/dlarray.dlconv.html ).
result = dlconv(X, K, bias, 'Stride', [1 1], 'Padding', 'same', 'DataFormat', 'SSCSB', 'WeightsFormat', 'SSCUUS');
However, this does not work. I get an error message saying: "Convolution with two or more groups does not support convolving over three or more dimensions."
My current workaround is to run the convolution separately for each of the 64 channels in a for-loop, treating the third spatial dimension as the channel dimension. The current code for the workaround is:
X: 256x256x64x3x20
K: 3x3x64x3
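for filt = 1:64
    % Treat the third spatial dimension (size 3) as the channel dimension,
    % so each call is a 2D convolution with 3 input channels and 1 filter.
    X(:, :, filt, 1, :) = dlconv(X(:, :, filt, :, :), K(:, :, filt, :), layer.bias, ...
        'Stride', [1 1], 'Padding', 'same', 'DataFormat', 'SSUCB', 'WeightsFormat', 'SSUC');
end
% The loop writes each result into slice 1 of the fourth dimension of X.
Z = squeeze(X(:, :, :, 1, :));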
However, this causes the convolution to take a significant amount of time to perform, especially for the backward pass.
My question is: Is there any computationally faster way to do this workaround? Is there a solution that doesn't have to involve looping?
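One possible loop-free reformulation, given only as a sketch: since the workaround above already treats the size-3 dimension as channels, the whole operation looks like a single grouped 2D convolution with 64 groups, 3 channels per group and 1 filter per group, which stays within the two spatial dimensions that grouped dlconv accepts. The snippet below uses small stand-in sizes, a scalar zero bias, and the names Xg, Kg, Zref and maxErr, none of which come from the layer code above. It assumes the documented grouped weights layout (filter height, filter width, channels per group, filters per group, number of groups) and that groups take contiguous blocks of channels. It has not been checked on R2021a specifically.

```matlab
% Small stand-in sizes (the post uses H = W = 256, C = 64, D = 3, B = 20).
H = 8; W = 8; C = 4; D = 3; B = 2;
X = rand(H, W, C, D, B);   % data laid out as in the post: S x S x C x S x B
K = rand(3, 3, C, D);      % one 3x3xD filter per channel

% Move the depth axis next to the channel axis and merge the two, so that
% the D slices of channel c become channels (c-1)*D+1 ... c*D.
Xg = reshape(permute(X, [1 2 4 3 5]), H, W, D*C, B);

% Grouped weights: filterH x filterW x channelsPerGroup x filtersPerGroup x numGroups.
Kg = reshape(permute(K, [1 2 4 3]), 3, 3, D, 1, C);

% One grouped 2D convolution with C groups replaces the loop over channels.
Z = dlconv(dlarray(Xg, 'SSCB'), Kg, 0, 'Stride', [1 1], 'Padding', 'same');
Z = extractdata(Z);        % H x W x C x B

% Reference result from explicit loops: filter2 computes a zero-padded
% 2D correlation, summed here over the D slices of each channel.
Zref = zeros(H, W, C, B);
for c = 1:C
    for b = 1:B
        for d = 1:D
            Zref(:, :, c, b) = Zref(:, :, c, b) + ...
                filter2(K(:, :, c, d), X(:, :, c, d, b), 'same');
        end
    end
end
maxErr = max(abs(Z(:) - Zref(:)));   % largest deviation from the reference
```

If maxErr is at the level of floating-point round-off, the same permute and reshape steps could be applied to the 256x256x64x3x20 input and the 3x3x64x3 filters inside the layer, with the real bias in place of 0.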

Answers (0)
