The issue is that when I implement it (line by line, as my first test run) with a custom dataset (or with the recommended dataset), I get an error that says:
"Sampling layer is only for Train a variational autoencoder...".
Which means that the "Sampling layer" in the encoder and the "feature input layer" in the decoder do not exist for me to use.
Is there any way I can implement a VAE in the current version of MATLAB? If so, what kind of layers and layer setup should I use?
%% Load the MNIST training and test images.
trainImagesFile = "train-images-idx3-ubyte.gz";
testImagesFile = "t10k-images-idx3-ubyte.gz";
XTrain = processImagesMNIST(trainImagesFile);
XTest = processImagesMNIST(testImagesFile);

%% Architecture parameters.
numLatentChannels = 16;
imageSize = [28 28 1];

% Encoder: two strided convolutions downsample the image, then a fully
% connected layer produces a 2*K vector (mean and log-variance) from which
% the custom samplingLayer draws z ~ N(mu, sigma^2).
layersE = [
    imageInputLayer(imageSize,Normalization="none")
    convolution2dLayer(3,32,Padding="same",Stride=2)
    reluLayer
    convolution2dLayer(3,64,Padding="same",Stride=2)
    reluLayer
    fullyConnectedLayer(2*numLatentChannels)
    samplingLayer];

% Decoder: project the latent vector to a 7x7x64 volume, then transposed
% convolutions upsample back to the input image size.
projectionSize = [7 7 64];
% BUG FIX: the channel count is the third element of imageSize.
% size(imageSize,1) is always 1 for a row vector and was only
% coincidentally correct for grayscale input; imageSize(3) also works for
% RGB data.
numInputChannels = imageSize(3);
layersD = [
    featureInputLayer(numLatentChannels)
    % BUG FIX: the projectAndReshapeLayer constructor defined below
    % requires all three arguments (output size, input channel count,
    % layer name); calling it with only the output size errors.
    projectAndReshapeLayer(projectionSize,numLatentChannels,"proj")
    transposedConv2dLayer(3,64,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,32,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,numInputChannels,Cropping="same")
    sigmoidLayer];

% Encoder and decoder are trained as two separate dlnetwork objects.
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);

%% Training options.
numEpochs = 30;
miniBatchSize = 128;
learnRate = 1e-3;

% Iterate observations along the 4th dimension; discard partial batches so
% every iteration sees a full mini-batch.
dsTrain = arrayDatastore(XTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB", ...
    PartialMiniBatch="discard");

% Adam optimizer state, kept separately for encoder and decoder.
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];

numObservationsTrain = size(XTrain,4);
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;

monitor = trainingProgressMonitor( ...
    Metrics="Loss", ...
    Info="Epoch", ...
    XLabel="Iteration");

%% Custom training loop.
epoch = 0;
iteration = 0;
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle data each epoch.
    shuffle(mbq);

    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;
        X = next(mbq);

        % Evaluate the ELBO loss and gradients for both networks.
        [loss,gradientsE,gradientsD] = dlfeval(@modelLoss,netE,netD,X);

        % Adam update for each network.
        [netE,trailingAvgE,trailingAvgSqE] = adamupdate(netE, ...
            gradientsE,trailingAvgE,trailingAvgSqE,iteration,learnRate);
        [netD,trailingAvgD,trailingAvgSqD] = adamupdate(netD, ...
            gradientsD,trailingAvgD,trailingAvgSqD,iteration,learnRate);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100*iteration/numIterations;
    end
end

%% Evaluate reconstruction error on the held-out test set.
dsTest = arrayDatastore(XTest,IterationDimension=4);
numOutputs = 1;
mbqTest = minibatchqueue(dsTest,numOutputs, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");
YTest = modelPredictions(netE,netD,mbqTest);

% Per-image mean squared error over spatial and channel dimensions.
err = mean((XTest-YTest).^2,[1 2 3]);
figure
histogram(err)
xlabel("Error")
ylabel("Frequency")
title("Test Data")

%% Generate new images by decoding samples from the latent prior.
numImages = 64;
ZNew = randn(numLatentChannels,numImages);
ZNew = dlarray(ZNew,"CB");
YNew = predict(netD,ZNew);
YNew = extractdata(YNew);
figure
I = imtile(YNew);
imshow(I)
title("Generated Images")
function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
% modelLoss  Evaluate the VAE loss and its gradients for one mini-batch.
% Encodes X into a latent sample (plus the distribution statistics),
% decodes it back, computes the ELBO loss, and differentiates with respect
% to the learnable parameters of both networks.

% Encode: sampled latent code plus mean and log-variance.
[Z,mu,logSigmaSq] = forward(netE,X);

% Decode the latent sample back to image space.
XReconstructed = forward(netD,Z);

% ELBO loss and gradients for encoder and decoder parameters.
loss = elboLoss(XReconstructed,X,mu,logSigmaSq);
[gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
end
function loss = elboLoss(Y,T,mu,logSigmaSq)
% elboLoss  Negative evidence lower bound for a VAE.
% Combines the reconstruction error between prediction Y and target T with
% the closed-form KL divergence of N(mu, exp(logSigmaSq)) from the
% standard normal prior.

% Reconstruction term: mean squared error.
reconstructionLoss = mse(Y,T);

% KL divergence, summed over latent channels (dim 1) and averaged over the
% mini-batch.
klPerObservation = -0.5 * sum(1 + logSigmaSq - mu.^2 - exp(logSigmaSq),1);
KL = mean(klPerObservation);

loss = reconstructionLoss + KL;
end
function Y = modelPredictions(netE,netD,mbq)
% modelPredictions  Reconstruct every mini-batch in the queue.
% Passes each batch through the encoder then the decoder and concatenates
% the reconstructions along the observation (4th) dimension.
Y = [];
while hasdata(mbq)
    XBatch = next(mbq);

    % Encode, then decode.
    ZBatch = predict(netE,XBatch);
    YBatch = predict(netD,ZBatch);

    % Accumulate plain numeric reconstructions.
    Y = cat(4,Y,extractdata(YBatch));
end
end
function X = preprocessMiniBatch(dataX)
% preprocessMiniBatch  Stack a cell array of image arrays into one batch.
% Each cell holds one observation; they are concatenated along the 4th
% (observation) dimension.
images = dataX;
X = cat(4,images{:});
end
function X = processImagesMNIST(filename)
% processImagesMNIST  Unzip and read an MNIST IDX image file.
%
% X = processImagesMNIST(filename) extracts the gzipped IDX file into a
% temp folder, parses the header, and returns the images as a
% numRows-by-numCols-by-1-by-numImages array with pixel values scaled to
% [0,1].
dataFolder = fullfile(tempdir,'mnist');
gunzip(filename,dataFolder)
[~,name,~] = fileparts(filename);

% IDX files are big-endian.
[fileID,errmsg] = fopen(fullfile(dataFolder,name),'r','b');
if fileID < 0
    error(errmsg);
end

% BUG FIX: the images file has magic number 2051 (2049 identifies the
% *labels* file). The original code only printed a message when it saw the
% labels magic and never validated the images header.
magicNum = fread(fileID,1,'int32',0,'b');
if magicNum ~= 2051
    fclose(fileID);
    error('Unexpected magic number %d; %s is not an IDX image file.', ...
        magicNum,name);
end

% BUG FIX: read the full image-file header (count, rows, cols). The
% original skipped the row/column words, so the pixel stream started 8
% bytes too early.
numImages = fread(fileID,1,'int32',0,'b');
numRows = fread(fileID,1,'int32',0,'b');
numCols = fread(fileID,1,'int32',0,'b');
fprintf('Number of images in the dataset: %6d ...\n',numImages);

rawData = fread(fileID,inf,'unsigned char');
fclose(fileID);

% BUG FIX: the reshape/normalize steps were commented out, so the function
% returned a 1-by-N row vector of raw bytes, which breaks all downstream
% SSCB processing. IDX stores each image row-major, hence the permute to
% swap rows and columns.
X = reshape(rawData,numCols,numRows,numImages);
X = permute(X,[2 1 3]);
X = X./255;
X = reshape(X,[numRows,numCols,1,numImages]);
end

6 Comments

Can you add the code so we can see where exactly the error is happening?
I don't have the code with me anymore. But I followed the link line by line, with the dataset installed in the same directory as the MATLAB file.
I uploaded the code, for you to check.
Hi,
The layers samplingLayer and projectAndReshapeLayer are custom layers that were written for the example. They won't be automatically in the path, but MATLAB is detecting that they exist in an example folder (which is the error that you see I believe). If you want to use them, you need to make sure the files that define those custom layers are in the path using addpath, or you can just copy the files in your folder.
Can you try this and see if that helps?
I was able to fix it! Thank you for the help!
Great! I'll add this as an answer then

Sign in to comment.

 Accepted Answer

Yoann Roth
Yoann Roth on 8 Feb 2023

0 votes

The layers samplingLayer and projectAndReshapeLayer are custom layers that were written for the example. They won't be automatically in the path, but MATLAB is detecting that they exist in an example folder (which is the error that you see I believe). If you want to use them, you need to make sure the files that define those custom layers are in the path using addpath, or you can just copy the files in your folder.

2 Comments

classdef samplingLayer < nnet.layer.Layer
    % samplingLayer  Reparameterization-trick sampling layer for VAEs.
    % Splits its input into mean and log-variance halves and emits a
    % random sample z = mu + sigma .* epsilon, together with the
    % statistics needed for the KL term of the loss.
    methods
        function layer = samplingLayer(args)
            % layer = samplingLayer creates a sampling layer for VAEs.
            %
            % layer = samplingLayer(Name=name) also specifies the layer
            % name.

            % Parse the optional name-value argument.
            arguments
                args.Name = "";
            end

            % Layer properties.
            layer.Name = args.Name;
            layer.Type = "Sampling";
            layer.Description = "Mean and log-variance sampling";
            layer.OutputNames = ["out" "mean" "log-variance"];
        end

        function [Z,mu,logSigmaSq] = predict(~,X)
            % Draw a latent sample from the statistics encoded in X.
            %
            % Inputs:
            %   X - Concatenated data where rows 1:K hold the means and
            %       rows K+1:end hold the log-variances; K is the number
            %       of latent channels.
            % Outputs:
            %   Z          - Sampled latent code.
            %   mu         - Mean vector.
            %   logSigmaSq - Log-variance vector.

            % First half of the rows is the mean, second half the
            % log-variance.
            K = size(X,1)/2;
            batchSize = size(X,2);
            mu = X(1:K,:);
            logSigmaSq = X(K+1:end,:);

            % Reparameterization trick: z = mu + sigma .* epsilon with
            % epsilon ~ N(0,1), keeping the sample differentiable in mu
            % and sigma.
            noise = randn(K,batchSize,"like",X);
            Z = mu + exp(.5 * logSigmaSq) .* noise;
        end
    end
end
Andrew
Andrew on 15 Feb 2024
classdef projectAndReshapeLayer < nnet.layer.Layer
% projectAndReshapeLayer  Custom layer that applies a learned fully
% connected projection to its input and reshapes the result to a fixed
% spatial size, e.g. as the first layer of a decoder/generator.
properties
% (Optional) Layer properties.
% Target output size [H W C] that the projected vector is reshaped to.
OutputSize
end
properties (Learnable)
% Layer learnable parameters.
% Weights: prod(OutputSize)-by-numChannels projection matrix.
Weights
% Bias: prod(OutputSize)-by-1 additive bias.
Bias
end
methods
function layer = projectAndReshapeLayer(outputSize, numChannels, name)
% Create a projectAndReshapeLayer.
% NOTE(review): all three arguments are required -- there are no
% defaults. Calling this constructor with only outputSize errors;
% pass the input channel count and a layer name as well.
% Set layer name.
layer.Name = name;
% Set layer description.
layer.Description = "Project and reshape layer with output size " + join(string(outputSize));
% Set layer type.
layer.Type = "Project and Reshape";
% Set output size.
layer.OutputSize = outputSize;
% Initialize fully connect weights and bias.
% Weights are Glorot-uniform; bias starts at zero (single precision).
fcSize = prod(outputSize);
layer.Weights = initializeGlorot(fcSize, numChannels);
layer.Bias = zeros(fcSize, 1, 'single');
end
function Z = predict(layer, X)
% Forward input data through the layer at prediction time and
% output the result.
%
% Inputs:
% layer - Layer to forward propagate through
% X - Input data, specified as a 1-by-1-by-C-by-N
% dlarray, where N is the mini-batch size.
% Outputs:
% Z - Output of layer forward function returned as
% an sz(1)-by-sz(2)-by-sz(3)-by-N dlarray,
% where sz is the layer output size and N is
% the mini-batch size.
% Fully connect.
weights = layer.Weights;
bias = layer.Bias;
% NOTE(review): 'DataFormat','SSCB' assumes the incoming data is
% laid out as 1-by-1-by-C-by-N (image-style input). If this layer
% follows a featureInputLayer, the data is C-by-B instead -- TODO
% confirm the format matches before reusing this layer.
X = fullyconnect(X,weights,bias,'DataFormat','SSCB');
% Reshape.
% [] in the 4th dimension lets MATLAB infer the batch size.
outputSize = layer.OutputSize;
Z = reshape(X, outputSize(1), outputSize(2), outputSize(3), []);
end
end
end
function weights = initializeGlorot(numOut, numIn)
% initializeGlorot  Glorot (Xavier) uniform weight initialization.
% Returns a numOut-by-numIn single matrix sampled uniformly from
% [-b, b] with b = sqrt(6 / (numIn + numOut)).
bound = sqrt( 6 / (numIn + numOut) );
weights = bound * (2 * rand([numOut, numIn], 'single') - 1);
end

Sign in to comment.

More Answers (0)

Asked:

on 27 Jan 2023

Commented:

on 15 Feb 2024

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!