VAE in MATLAB R2022b
Basically I want to recreate the VAE from this page: https://au.mathworks.com/help/deeplearning/ug/train-a-variational-autoencoder-vae-to-generate-images.html
The issue is that when I implement it (line by line, as my first test run) with a custom dataset (or with the recommended dataset), I get an error that says:
"Sampling layer is only for Train a variational autoencoder...".
This suggests that the "Sampling layer" in the encoder and the "feature input layer" in the decoder do not exist for me to use.
Is there any way I can implement a VAE in the current version of MATLAB? If so, what kind of layers and layer setup should I use?
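If the example's custom layer files are not available, one possible layer setup (a sketch under stated assumptions, not the example's own code) uses only layers that ship with Deep Learning Toolbox. The encoder stops at fullyConnectedLayer(2*numLatentChannels), and the decoder replaces projectAndReshapeLayer with a fullyConnectedLayer followed by a functionLayer that reshapes the output to 7-by-7-by-64. This assumes functionLayer is available (it exists from R2021b, so R2022b should have it) and that the sampling step is moved out of the network into the loss function; reshapeFcn is a hypothetical helper name.

```matlab
% Sketch: VAE encoder/decoder built only from built-in layers.
numLatentChannels = 16;
imageSize = [28 28 1];
projectionSize = [7 7 64];
numInputChannels = imageSize(3);

% Encoder: outputs [mu; logSigmaSq] as one 2*numLatentChannels vector
% per observation. Sampling is NOT done inside this network.
layersE = [
    imageInputLayer(imageSize,Normalization="none")
    convolution2dLayer(3,32,Padding="same",Stride=2)
    reluLayer
    convolution2dLayer(3,64,Padding="same",Stride=2)
    reluLayer
    fullyConnectedLayer(2*numLatentChannels)];

% Reshape a "CB" array (prod(projectionSize)-by-B) into a "SSCB" array.
% stripdims removes the format labels before the reshape.
reshapeFcn = @(X) dlarray(reshape(stripdims(X), ...
    projectionSize(1),projectionSize(2),projectionSize(3),[]),"SSCB");

% Decoder: fullyConnectedLayer + functionLayer stand in for the
% example's custom projectAndReshapeLayer.
layersD = [
    featureInputLayer(numLatentChannels)
    fullyConnectedLayer(prod(projectionSize))
    functionLayer(reshapeFcn,Formattable=true,Name="reshape")
    transposedConv2dLayer(3,64,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,32,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,numInputChannels,Cropping="same")
    sigmoidLayer];

netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
```

With this encoder, forward(netE,X) returns a single 2*numLatentChannels-by-miniBatchSize "CB" array instead of three outputs, so the [Z,mu,logSigmaSq] = forward(netE,X) line in modelLoss (and the predict(netE,X) call in modelPredictions) must be followed by a step that splits that array into mu and logSigmaSq and draws Z from them.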

trainImagesFile = "train-images-idx3-ubyte.gz";
testImagesFile = "t10k-images-idx3-ubyte.gz";
XTrain = processImagesMNIST(trainImagesFile);
XTest = processImagesMNIST(testImagesFile);
numLatentChannels = 16;
imageSize = [28 28 1];
layersE = [
    imageInputLayer(imageSize,Normalization="none")
    convolution2dLayer(3,32,Padding="same",Stride=2)
    reluLayer
    convolution2dLayer(3,64,Padding="same",Stride=2)
    reluLayer
    fullyConnectedLayer(2*numLatentChannels)
    samplingLayer];
projectionSize = [7 7 64];
% Number of image channels (1 for MNIST).
numInputChannels = imageSize(3);
layersD = [
    featureInputLayer(numLatentChannels)
    projectAndReshapeLayer(projectionSize)
    transposedConv2dLayer(3,64,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,32,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,numInputChannels,Cropping="same")
    sigmoidLayer];
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
numEpochs = 30;
miniBatchSize = 128;
learnRate = 1e-3;
dsTrain = arrayDatastore(XTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB", ...
    PartialMiniBatch="discard");
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];
numObservationsTrain = size(XTrain,4);
% Partial mini-batches are discarded, so round down.
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
monitor = trainingProgressMonitor( ...
    Metrics="Loss", ...
    Info="Epoch", ...
    XLabel="Iteration");
epoch = 0;
iteration = 0;
% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;
    % Shuffle data.
    shuffle(mbq);
    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;
        % Read mini-batch of data.
        X = next(mbq);
        % Evaluate loss and gradients.
        [loss,gradientsE,gradientsD] = dlfeval(@modelLoss,netE,netD,X);
        % Update learnable parameters.
        [netE,trailingAvgE,trailingAvgSqE] = adamupdate(netE, ...
            gradientsE,trailingAvgE,trailingAvgSqE,iteration,learnRate);
        [netD,trailingAvgD,trailingAvgSqD] = adamupdate(netD, ...
            gradientsD,trailingAvgD,trailingAvgSqD,iteration,learnRate);
        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100*iteration/numIterations;
    end
end
dsTest = arrayDatastore(XTest,IterationDimension=4);
numOutputs = 1;
mbqTest = minibatchqueue(dsTest,numOutputs, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");
YTest = modelPredictions(netE,netD,mbqTest);
err = mean((XTest-YTest).^2,[1 2 3]);
figure
histogram(err)
xlabel("Error")
ylabel("Frequency")
title("Test Data")
numImages = 64;
ZNew = randn(numLatentChannels,numImages);
ZNew = dlarray(ZNew,"CB");
YNew = predict(netD,ZNew);
YNew = extractdata(YNew);
figure
I = imtile(YNew);
imshow(I)
title("Generated Images")
function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
% Forward through encoder.
[Z,mu,logSigmaSq] = forward(netE,X);
% Forward through decoder.
Y = forward(netD,Z);
% Calculate loss and gradients.
loss = elboLoss(Y,X,mu,logSigmaSq);
[gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
end
function loss = elboLoss(Y,T,mu,logSigmaSq)
% Reconstruction loss.
reconstructionLoss = mse(Y,T);
% KL divergence.
KL = -0.5 * sum(1 + logSigmaSq - mu.^2 - exp(logSigmaSq),1);
KL = mean(KL);
% Combined loss.
loss = reconstructionLoss + KL;
end
function Y = modelPredictions(netE,netD,mbq)
Y = [];
% Loop over mini-batches.
while hasdata(mbq)
    X = next(mbq);
    % Forward through encoder.
    Z = predict(netE,X);
    % Forward through decoder.
    XGenerated = predict(netD,Z);
    % Extract and concatenate predictions.
    Y = cat(4,Y,extractdata(XGenerated));
end
end
function X = preprocessMiniBatch(dataX)
% Concatenate.
X = cat(4,dataX{:});
end
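function X = processImagesMNIST(filename)
% Read an MNIST IDX image file into a 28-by-28-by-1-by-N array in [0,1].
dataFolder = fullfile(tempdir,'mnist');
gunzip(filename,dataFolder)
[~,name,~] = fileparts(filename);
[fileID,errmsg] = fopen(fullfile(dataFolder,name),'r','b');
if fileID < 0
    error(errmsg);
end
% The magic number is 2051 for image files (2049 is for label files).
magicNum = fread(fileID,1,'int32',0,'b');
if magicNum == 2051
    fprintf('\nRead MNIST image data...\n')
end
numImages = fread(fileID,1,'int32',0,'b');
fprintf('Number of images in the dataset: %6d ...\n',numImages);
% Image files store the number of rows and columns before the pixels.
numRows = fread(fileID,1,'int32',0,'b');
numCols = fread(fileID,1,'int32',0,'b');
X = fread(fileID,inf,'unsigned char');
fclose(fileID);
% Pixels are stored row by row, so reshape and swap the first two dims.
X = reshape(X,numCols,numRows,numImages);
X = permute(X,[2 1 3]);
X = X./255;
X = reshape(X,[numRows,numCols,1,numImages]);
end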
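The work that samplingLayer does is the reparameterization trick: split the encoder output into mu and logSigmaSq, then form Z = mu + exp(0.5*logSigmaSq).*epsilon with epsilon drawn from a standard normal distribution, so gradients can flow back through mu and logSigmaSq. A minimal sketch of doing this in the loss function instead of in a custom layer, assuming the encoder ends at fullyConnectedLayer(2*numLatentChannels) with no samplingLayer. sampleLatent is a hypothetical helper name; the modelLoss here would replace the one in the script above and reuses its elboLoss. Because MATLAB scripts require local functions at the end, save this as a script file.

```matlab
% Quick check of the sampling step on a random encoder-sized output.
numLatentChannels = 16;
miniBatchSize = 8;
Y = dlarray(randn(2*numLatentChannels,miniBatchSize,"single"),"CB");
[Z,mu,logSigmaSq] = sampleLatent(Y);
disp(size(Z))

% With a very negative log-variance, the sample collapses onto mu (= 1).
Ytight = dlarray([ones(4,3); -100*ones(4,3)],"CB");
Ztight = sampleLatent(Ytight);
disp(max(abs(extractdata(Ztight) - 1),[],"all"))

function [Z,mu,logSigmaSq] = sampleLatent(Y)
% Split the "CB" encoder output [mu; logSigmaSq] and draw
% Z = mu + sigma .* epsilon (reparameterization trick).
numLatentChannels = size(Y,1)/2;
mu = Y(1:numLatentChannels,:);
logSigmaSq = Y(numLatentChannels+1:end,:);
% Noise is a plain numeric array of the same class as the data,
% so it is treated as a constant when differentiating.
epsilon = randn(size(mu),"like",extractdata(mu));
Z = mu + exp(0.5*logSigmaSq) .* epsilon;
end

function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
% Encoder returns [mu; logSigmaSq] directly (no samplingLayer).
Y = forward(netE,X);
[Z,mu,logSigmaSq] = sampleLatent(Y);
% Forward through decoder.
XRecon = forward(netD,Z);
% elboLoss is the function defined in the question's script.
loss = elboLoss(XRecon,X,mu,logSigmaSq);
[gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
end
```

modelPredictions needs the same kind of change, for example Z = sampleLatent(predict(netE,X)); before calling predict(netD,Z), because the encoder output no longer has numLatentChannels rows.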
Arkadiy Turevskiy
on 31 Jan 2023
Can you add the code so we can see where exactly the error is happening?
Glacial Claw
on 31 Jan 2023
Glacial Claw
on 7 Feb 2023
Yoann Roth
on 7 Feb 2023
Hi,
The layers samplingLayer and projectAndReshapeLayer are custom layers that were written for the example. They are not on the MATLAB path automatically, but MATLAB detects that they exist in an example folder, which I believe is what produces the error you see. To use them, make sure the files that define these custom layers are on the path (using addpath), or simply copy the files into your working folder.
Can you try this and see if that helps?
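A minimal way to apply this suggestion, assuming samplingLayer.m and projectAndReshapeLayer.m from the example have been saved to a folder (the folder name vaeExampleFiles below is hypothetical), is to add that folder to the path and then check where MATLAB resolves each layer from:

```matlab
% Hypothetical location: replace with the folder that contains
% samplingLayer.m and projectAndReshapeLayer.m from the example.
exampleFolder = fullfile(pwd,"vaeExampleFiles");

if isfolder(exampleFolder)
    addpath(exampleFolder)
else
    warning("Folder %s not found; copy the custom layer files into the current folder instead.",exampleFolder)
end

% Report where MATLAB will load each custom layer from.
layerFiles = ["samplingLayer" "projectAndReshapeLayer"];
for k = 1:numel(layerFiles)
    location = which(char(layerFiles(k)));
    if isempty(location)
        fprintf("%s is NOT on the path.\n",layerFiles(k));
    else
        fprintf("%s found at %s\n",layerFiles(k),location);
    end
end
```

If both names report a location, the layersE and layersD arrays from the question should construct without the "only for Train a variational autoencoder" message.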
Glacial Claw
on 8 Feb 2023
Yoann Roth
on 8 Feb 2023
Great! I'll add this as an answer, then.