How to implement PyTorch's Linear layer in Matlab?
Hello,
The problem is that Linear does not flatten its inputs whereas Matlab's fullyConnectedLayer does, so the two are not equivalent.
Thx,
J
Answers (4)
One possibility might be to express the linear layer as a cascade of fullyConnectedLayer followed by a functionLayer. The functionLayer can reshape the flattened input back to the form you want,
layer = functionLayer(@(X)reshape(X,[h,w,c]));
9 Comments
John Smith
on 12 Feb 2023
Edited: John Smith
on 12 Feb 2023
So you're saying you want a separate linear transform applied independently to each of the channels? Then you could use a groupedConvolution2dLayer, where the filter size is the size of a complete channel,
layer = groupedConvolution2dLayer(filterSize,numFiltersperGroup,'channel-wise')
and numFiltersperGroup is chosen depending on how many outputs your linear transforms are supposed to have.
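For readers who want to see why this works, here is a hedged NumPy sketch (all names are illustrative, not the MATLAB API): a channel-wise grouped convolution whose filter covers the entire H-by-W channel collapses to one output value per filter, i.e., an independent linear map applied to each channel.

```python
import numpy as np

# Illustrative sketch: a "grouped convolution" with a full-channel
# filter is the same as a separate linear transform per channel.
H, W, C = 3, 3, 2          # input height, width, channels
n_out = 4                  # numFiltersPerGroup: outputs per channel
rng = np.random.default_rng(0)
X = rng.standard_normal((H, W, C))
# one (n_out, H*W) weight matrix per channel group
A = rng.standard_normal((C, n_out, H * W))

# full-channel filter: each of the n_out filters correlates with its
# channel at exactly one position, producing a single scalar
Y_conv = np.stack([A[c] @ X[:, :, c].ravel() for c in range(C)], axis=1)

# the same computation stated directly as per-channel matrix multiply
Y_mat = np.einsum('cij,jc->ic', A, X.reshape(H * W, C))
print(np.allclose(Y_conv, Y_mat))  # the two formulations agree
```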
John Smith
on 13 Feb 2023
"The 2d-convolution performs element-wise multiplication of the kernel with the input and sums all the intermediate results together, which is not what matrix multiplication does."
Yes, it is. That's exactly what matrix multiplication does: each output element is an element-wise multiplication of a row with the input, summed.
"The kernel would need to be duplicated per channel and then the issue of divergence during training still might bite."
You'll need to elaborate on that. Are you trying to transform each channel with the same matrix or avoid it? The scheme I outlined for you applies a different matrix to each channel, but if you want to apply the same matrix, you can reshape the input X (assume dimensions are HxWxC) so that the channel dimension becomes spatial,
X=reshape(X,[H*W,C])
and then apply a conv2dLayer with an (H*W)x1xN filter with no padding. It has the exact same effect as multiplying each column (i.e., channel) of X with an Nx(H*W) matrix.
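A hedged NumPy sketch of that equivalence (variable names are mine, not the MATLAB API): with a full-height (H*W)x1 filter and no padding, each of the N filters touches each column exactly once, so the "convolution" is just a shared matrix applied to every channel column.

```python
import numpy as np

# Sketch: after reshaping X to (H*W, C), a valid convolution with an
# (H*W)x1xN filter reduces to multiplying every column (channel) by
# the same N x (H*W) matrix.
H, W, C, N = 3, 3, 2, 4
rng = np.random.default_rng(1)
X = rng.standard_normal((H, W, C))
M = rng.standard_normal((N, H * W))      # the shared linear transform

Xcols = X.reshape(H * W, C)              # each channel becomes a column
# "convolution": filter n slides over each column, but the full-height
# filter fits at only one position, so it is a single dot product
Y_conv = np.stack([M[n] @ Xcols for n in range(N)], axis=0)  # (N, C)

Y_mat = M @ Xcols                        # plain matrix multiplication
print(np.allclose(Y_conv, Y_mat))
```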
John Smith
on 13 Feb 2023
"We need the result of multiplying two matrices, which is another matrix."
Yes, that is what my suggestion gives you. What I propose places no limit on the shape or size of the input and output.
More generally, there is no linear transform that can't be implemented using conv layers in combination with reshape() and permute() functionLayers. The only thing that is lacking is a clear understanding of where you want the transformation data to be re-used, if at all. My current understanding is that you want it to be re-used channel-wise. In other words, all channels are to be subject to a common linear transform.
John Smith
on 13 Feb 2023
Edited: John Smith
on 13 Feb 2023
"This solution sums all channels together."
No, it won't. (Keep in mind that this is the third solution I've proposed as information about your aims has come out.) After the reshaping, each channel is contained in its own column of X, and because the filter you apply to X is (H*W)x1xN, there is no way for the filter to combine elements from different columns.
John Smith
on 13 Feb 2023
John Smith
on 13 Feb 2023
Edited: John Smith
on 13 Feb 2023
Another possible way to interpret your question is that you are trying to apply pagemtimes to the input X with a non-learnable matrix A, where the different channels of X are the pages. That can also be done with a functionLayer, as illustrated below both with normal arrays and with dlarrays,
A=rand(4,3); %non-learnable matrix A
xdata=rand(3,3,2); %input layer data with 2 channels
multLayer=functionLayer(@(X) dlarray( pagemtimes(A,stripdims(X)) ,dims(X)) );
X=dlarray(xdata,'SSC');
Y=multLayer.predict(X)
%%Verify agreement with normal pagemtimes
ydata=pagemtimes(A,xdata)
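For anyone unfamiliar with pagemtimes, here is a small NumPy sketch (names are illustrative) of what the functionLayer above computes: the fixed matrix A premultiplies each channel ("page") of X independently.

```python
import numpy as np

# Sketch of pagemtimes(A, X) with channels as pages: A multiplies
# each H x W page of X on the left, independently per channel.
A = np.arange(12, dtype=float).reshape(4, 3)     # fixed 4x3 matrix
X = np.arange(18, dtype=float).reshape(3, 3, 2)  # 3x3 input, 2 channels

# channel-by-channel multiplication, as pagemtimes does
Y_loop = np.stack([A @ X[:, :, c] for c in range(2)], axis=2)

# the same thing as a single einsum over the channel axis
Y_ein = np.einsum('ij,jkc->ikc', A, X)
print(np.allclose(Y_loop, Y_ein), Y_loop.shape)  # True (4, 3, 2)
```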
3 Comments
John Smith
on 13 Feb 2023
The modification for the case where A is learnable is as below. I am using a pre-declared A here only so that I can demonstrate and test the response. In a real scenario, you wouldn't supply weights to the convolution2dLayer.
X=dlarray(rand(3,3,2),'SSC'); A=rand(4,3);
[h,w,c]=size(X);
L1=functionLayer( @(z) z(:,:) );
Lconv=convolution2dLayer([h,1],4,'Weights',permute(A,[2,3,4,1]));
L2=functionLayer(@(z)recoverShape(z,w,c) ,'Formattable',1);
net=dlnetwork([L1,Lconv,L2],X);
Yfinal=net.predict(X)
And as before, we can compare to the result of a plain-vanilla pagemtimes operation and see that it gives the same result:
Ycheck=pagemtimes(A, extractdata(X))
function out=recoverShape(z,w,c)
z=permute( stripdims(z), [3,2,1]);
out=dlarray(reshape(z,[],w,c),'SSC');
end
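The three-step pipeline above (columnize, convolve with a full-height filter, recover the shape) can be sketched in NumPy as follows; this is an illustration under my own naming, using Fortran (column-major) order to mimic MATLAB's z(:,:) and reshape semantics.

```python
import numpy as np

# Sketch of the learnable pipeline: columnize the input MATLAB-style,
# apply the full-height "convolution" (here just A @ columns), then
# reshape back. The result matches per-page multiplication by A.
H, W, C = 3, 3, 2
rng = np.random.default_rng(4)
X = rng.standard_normal((H, W, C))
A = rng.standard_normal((4, H))          # multiplies each page on the left

cols = X.reshape(H, W * C, order='F')    # column-major, like z(:,:)
Y = (A @ cols).reshape(4, W, C, order='F')  # recover the page structure

# reference: apply A to each channel directly, as pagemtimes would
Y_check = np.stack([A @ X[:, :, c] for c in range(C)], axis=2)
print(np.allclose(Y, Y_check))
```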
John Smith
on 14 Feb 2023
Edited: John Smith
on 14 Feb 2023
Another approach is to write your own custom layer for channel-wise matrix multiplication. I have attached a possible version of this,
X=rand(3,3,2);
L=pagemtimesLayer(4); %Custom layer - premultiplies channels by 4-row learnable matrix A
L=initialize(L, X);
Ypred=L.predict(X)
Ycheck=pagemtimes(L.A,X) %Check agreement with a direct call to pagemtimes()
8 Comments
John Smith
on 14 Feb 2023
Edited: John Smith
on 14 Feb 2023
Matt J
on 14 Feb 2023
Most of that you can do by using tensorprod instead of pagemtimes. I leave it to you to tweak the design to your liking.
John Smith
on 14 Feb 2023
That is unfortunate. However, adding support for dlarrays wouldn't be terribly onerous. Instead of relying on automatic differentiation, you would have to add a backward() method to the class,
function Z = predict(layer, X)
if isa(X,'dlarray')
X=extractdata(X);
end
Z = tensorprod(layer.A,X,layer.innerdim) + layer.b;
end
function [dLdX,dLdA,dLdb] = backward(layer,X,Z,dLdZ)
dZdX=layer.A;
dZdA=X; %X should really be transposed here in some tensorial sense.
%Or, the tensorprod dimensions below should be adjusted.
dLdX=tensorprod(dLdZ,dZdX,___);
dLdA=tensorprod(dLdZ,dZdA,___);
dLdb=dLdZ;
....
end
which should be easy for a simple linear operation.
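To make the backward() math concrete for the simple matrix case Z = A*X + b, here is a hedged NumPy sketch of the standard identities dLdX = A'*dLdZ and dLdA = dLdZ*X', with a finite-difference spot check (all names here are mine, for illustration only).

```python
import numpy as np

# Gradient identities for Z = A @ X + b with L = <dLdZ, Z>:
#   dLdX = A.T @ dLdZ,  dLdA = dLdZ @ X.T,  dLdb sums over columns
rng = np.random.default_rng(2)
A = rng.standard_normal((4, 3))
X = rng.standard_normal((3, 5))
b = rng.standard_normal((4, 1))
dLdZ = rng.standard_normal((4, 5))       # incoming gradient

def loss(A_, X_, b_):
    return np.sum(dLdZ * (A_ @ X_ + b_))  # L = <dLdZ, Z>

# analytic gradients
dLdX = A.T @ dLdZ
dLdA = dLdZ @ X.T
dLdb = dLdZ.sum(axis=1, keepdims=True)   # bias broadcast over columns

# finite-difference check on one entry of dLdX
eps = 1e-6
dX = np.zeros_like(X); dX[1, 2] = eps
fd = (loss(A, X + dX, b) - loss(A, X - dX, b)) / (2 * eps)
print(np.isclose(fd, dLdX[1, 2]))
```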
John Smith
on 15 Feb 2023
Edited: John Smith
on 15 Feb 2023
If your bias is a scalar, uniformly applied to all Z(i), then yes. But if you have a different bias for every Z(i) then, extrapolating from the example below, dZdb should be an identity operator, meaning that dLdb=dLdZ*dZdb=dLdZ. Granted, I haven't tested any of what I'm outlining.
syms x b [4,1]; syms A [4,4]
Z=A*x+b
dZdb = jacobian(Z,b)
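The symbolic check above can also be done numerically; this NumPy sketch builds the Jacobian of Z = A*x + b with respect to b by finite differences and confirms it is the identity, so dLdb = dLdZ.

```python
import numpy as np

# Numeric version of the symbolic jacobian(Z, b) check: for a separate
# bias per output, dZ/db is the identity matrix.
rng = np.random.default_rng(3)
A = rng.standard_normal((4, 4))
x = rng.standard_normal(4)
b = rng.standard_normal(4)

eps = 1e-6
J = np.empty((4, 4))
for i in range(4):
    db = np.zeros(4); db[i] = eps
    J[:, i] = ((A @ x + (b + db)) - (A @ x + (b - db))) / (2 * eps)
print(np.allclose(J, np.eye(4)))  # dZ/db is the identity
```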
John Smith
on 15 Feb 2023
Edited: John Smith
on 15 Feb 2023
Matt J
on 15 Feb 2023
That sounds right.
Although, part of me questions whether it was the best design for TMW to make the user responsible for summing over batched input in the backward() method, since that dimension should always be handled the same way.