ClassWeightsの設定方法
2 views (last 30 days)
Show older comments
深層学習を使用したセマンティック セグメンテーションhttps://jp.mathworks.com/help/releases/R2018a/vision/examples/semantic-segmentation-using-deep-learning.html
をもとに自分で用意したデータセットで解析を行ったところ以下のようなエラーが出ました.
以下にi_learningのコードを示します.
%ネットワークの作成
imageSize = [360 480 3];
numClasses = numel(classes);
lgraph = segnetLayers(imageSize,numClasses,'vgg16');
%クラスの重み付けを使用したクラスのバランス調整
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ imageFreq;
pxLayer = pixelClassificationLayer('Name','labels','ClassNames',tbl.Name,'ClassWeights',classWeights);
lgraph = removeLayers(lgraph,'pixelLabels');
lgraph = addLayers(lgraph, pxLayer);
lgraph = connectLayers(lgraph,'softmax','labels');
%学習オプションの選択
options = trainingOptions('sgdm', ...
'Momentum',0.9, ...
'InitialLearnRate',1e-3, ...
'L2Regularization',0.0005, ...
'MaxEpochs',100, ...
'MiniBatchSize',2, ...
'Shuffle','every-epoch', ...
'VerboseFrequency',2);
%データ拡張
augmenter = imageDataAugmenter('RandXReflection',true,...
'RandXTranslation',[-10 10],'RandYTranslation',[-10 10]);
%学習の開始
pximds = pixelLabelImageDatastore(imdsTrain,pxdsTrain,'DataAugmentation',augmenter);
net= trainNetwork(pximds,lgraph,options);
0 Comments
Accepted Answer
Kenta
on 18 Dec 2019
classWeights
と入力して、それぞれの値を教えてもらえますか?訓練データに、ある稀なラベルが含まれていなくて0で割っている状態なのではないかと思いました。
classWeights = median(imageFreq) ./ (imageFreq+0.0001);
などとすれば回避できると思いました。
10 Comments
More Answers (0)
See Also
Categories
Find more on モデルの作成と評価 in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!