ニューラルネットワー​クの学習をdoubl​e型で行うことはでき​ますか?

4 views (last 30 days)
Fumiya Watanabe
Fumiya Watanabe on 26 Jun 2018
Commented: Fumiya Watanabe on 5 Jul 2018
ニューラルネットワークの学習をdouble型で行うことはできますか?
現在、ある実数値ベクトルを入力とする回帰問題をNeural Network Toolboxを用いて実現しようとしています。 このベクトル入力を画像入力として扱うことで実現を考えています。しかしながら、trainNetworkを実行するとsingle型として扱われてしまう問題が生じており、解決法がわからず困っております。
例えば、次の自作の回帰層を考えます。
classdef testLayer < nnet.layer.RegressionLayer
methods
function layer = testLayer()
end
function loss = forwardLoss(layer, Y, T)
loss = gpuArray(0);
end
function dLdX = backwardLoss(layer, Y, T)
dLdX = gpuArray(zeros(size(Y)));
end
end
end
この自作回帰層を用いて、次のように学習を実行します。
%%学習データ
x_in = rand(10, 1, 1, 6);
y_tr = rand(6, 5);
%%層構造とオプションの定義
layers = [
imageInputLayer([10 1 1], 'Normalization', 'none', 'Name', 'Input')
fullyConnectedLayer(2, 'Name', 'Layer1')
reluLayer('Name', 'ReLU1')
fullyConnectedLayer(5, 'Name', 'Output')
testLayer
];
layers(end).Name = 'Regression';
options = trainingOptions(...
'sgdm',...
'InitialLearnRate', 0.001, ...
'MiniBatchSize', 3, ...
'MaxEpochs', 1);
%%学習開始
net = trainNetwork(x_in, y_tr, layers, options);
すると、次のエラーが発生します。
エラー: trainNetwork (line 154)
Incorrect type of dLdX for 'backwardLoss' in the output layer. Expected gpuArray of underlying type 'single', but instead has
underlying type 'double'.
上記の自作回帰層で、gpuArrayの内部をsingleにキャストすることで実行することが可能となるのですが、実際に使っている自作回帰層ではdouble型でないと計算できない関数を利用しているため、
function loss = forwardLoss(layer, Y, T)
loss = gpuArray(single(myfun(double(Y), double(T))));
end
のようなキャストをしていく必要が生じてしまいます。これを避けるために学習をdouble型で実行したいのですが、解決法はありますでしょうか。

Accepted Answer

Naoya
Naoya on 29 Jun 2018
Neural Network Toolbox で提供される 畳み込みニューラルネットワークですが、trainNetwork 側で与えるデータ型は single, double 両方を受け付けます。
しかしながら、基本的にGPU上では単精度演算として扱われますので、GPU へ渡すゲートウェイとなるデータ型は single型となってしまいます。
  3 Comments
Naoya
Naoya on 3 Jul 2018
ご連絡ありがとうございます。 cpuモードの場合でも backwardLoss 関数のゲートウェイは single型にする必要があります。
Fumiya Watanabe
Fumiya Watanabe on 5 Jul 2018
ご回答ありがとうございます。
入力としてはdouble型を受け付けるが、計算内部はGPU・CPUどちらの場合でもsingle型で実行される形になっており、自作の層を扱う場合はsingle型でほかの層とのやり取りが必要であると理解いたしました。 ありがとうございました。

Sign in to comment.

More Answers (0)

Products


Release

R2018a

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!