どのように Stacked Autoencoder でデコード処理までを含ませることができますか?

5 views (last 30 days)
Stacked Autoencoderを用いて入力画像を再復元したいと考えています。
EncoderとDecoderが、ともに複数の層をもつようなネットワークを作成する方法を、教えてください。

Accepted Answer

MathWorks Support Team
MathWorks Support Team on 30 Sep 2020
Edited: MathWorks Support Team on 30 Sep 2020
スタックされたネットワークに対して、 Decoder 側を追加するようなスマートな機能は、Neural Network Toolbox (R2017a) では提供されておりません。
Toolbox で、AutoEncoder の 機能が提供される以前に、 AutoEncoder を実現していた例題が下記(※)があり、カスタムネットワークを作成いただくような形で、進めていただくような方法となります。
※ MathWorks Accont によるログインが必要となりますこと、ご了承ください。
上記については、通常の、 AutoEncoder による「識別」を目的とした例題となっておりますので、カスタムネットワークの定義の全体像の紹介までとなりますが、実際に、スタックされたネットワークに対して、 Decoder 側を追加するような方法としては、次のようなスクリプト例が基本構成となります。
X = abalone_dataset; % 8x4177 データを読み込み
% 入力、隠れ層のユニット数
inp_num = size(X,1); % 入力層
hid_num1 = 7; % 隠れ層 1
hid_num2 = 6 ; % 隠れ層 2
% AutoEncoder 実施
auto1 = trainAutoencoder(X,hid_num1,'DecoderTransferFunction','pureline','Scale',false);
tmp1 = encode(auto1,X);
auto2 = trainAutoencoder(tmp1,hid_num2,'DecoderTransferFunction','purelin','Scale',false);
tmp2 = encode(auto2,tmp1);
% 全体のネットワークを作成
fnet = network;
fnet.numInputs = 1;
fnet.numLayers = 4;
fnet.inputConnect(1,1) = 1;
fnet.layerConnect(2,1) = 1;
fnet.layerConnect(3,2) = 1;
fnet.layerConnect(4,3) = 1;
fnet.outputConnect(1,4) = 1;
fnet.biasConnect = [1;1;1;1];
fnet.inputs{1}.size = inp_num;
fnet.layers{1}.size = hid_num1;
fnet.layers{2}.size = hid_num2;
fnet.layers{3}.size = hid_num1;
fnet.layers{4}.size = inp_num;
fnet.layers{1}.transferFcn = 'logsig';
fnet.layers{2}.transferFcn = 'logsig';
fnet.layers{3}.transferFcn = 'purelin';
fnet.layers{4}.transferFcn = 'purelin';
fnet.divideFcn = 'dividetrain';
fnet.IW{1,1} = auto1.EncoderWeights;
fnet.b{1} = auto1.EncoderBiases;
fnet.LW{2,1} = auto2.EncoderWeights;
fnet.b{2} = auto2.EncoderBiases;
fnet.LW{3,2} = auto2.DecoderWeights;
fnet.b{3} = auto2.DecoderBiases;
fnet.LW{4,3} = auto1.DecoderWeights;
fnet.b{4} = auto1.DecoderBiases;
fnet.trainFcn = 'trainscg';
fnet.trainParam.epochs = 1000;
fnet = train(fnet,X,X);
Y = fnet(X);
hold on
for n = 1:inp_num
plot(X(n,:),Y(n,:),'.')
end
grid on
plot([-0.5 3.5],[-0.5 3.5])

More Answers (0)

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!