Train Network Using Federated Learning

This example shows how to train a network using federated learning. Federated learning is a technique that enables you to train a network in a distributed, decentralized way [1].

Federated learning allows you to train a model using data from different sources without moving the data to a central location, even if the individual data sources do not match the overall distribution of the data set. Data of this kind is known as non-independent and identically distributed (non-IID) data. Federated learning is especially useful when the training data is large or when transferring the training data raises privacy concerns.

Instead of distributing data, the federated learning technique trains multiple models, each in the same location as a data source. You can create a global model that has learned from all the data sources by periodically collecting and combining the learnable parameters of the locally trained models. In this way, you can train a global model without centrally processing any training data.

This example uses federated learning to train a classification model in parallel using a highly non-IID data set. The model is trained using the digits data set, which consists of 10,000 images of handwritten digits from 0 to 9. The example runs in parallel using 10 workers, each processing images of a single digit. By averaging the learnable parameters of the networks after each round of training, the models on each worker improve performance across all classes, without ever processing data of the other classes.

While data privacy is one of the applications of federated learning, this example does not deal with the details of maintaining data privacy and security. This example demonstrates the basic federated learning algorithm.

Set Up Parallel Environment

Create a parallel pool with the same number of workers as classes in the data set. For this example, use a process-based, local parallel pool with 10 workers.

cluster = parcluster("Processes");
cluster.NumWorkers = 10;
pool = parpool(cluster);
Starting parallel pool (parpool) using the 'Processes' profile ...
Connected to parallel pool with 10 workers.
numWorkers = pool.NumWorkers;

Load Data Set

All data used in this example is initially stored in a centralized location. To make this data highly non-IID, you need to distribute the data among the workers according to class. To create validation and test data sets, transfer a portion of data from the workers to the client. After the data is correctly set up, with training data of individual classes on the workers and test and validation data of all classes on the client, there is no further transfer of data during training.

Specify the folder containing the image data.

digitDatasetPath = fullfile(matlabroot,"toolbox","nnet","nndemos", ...

Distribute the data among the workers. Each worker receives images of only one digit, such that worker 1 receives all the images of the number 0, worker 2 receives images of the number 1, etc.

Images of each digit are stored in a separate folder with the name of that digit. On each worker, use the fullfile function to specify the path to a specific class folder. Then, create an imageDatastore that contains all images of that digit. Next, use the splitEachLabel function to randomly separate 30% of the data for use in validation and testing. Finally, create an augmentedImageDatastore containing the training data.

inputSize = [28 28 1];
spmd
    digitDatasetPath = fullfile(digitDatasetPath,num2str(spmdIndex - 1));
    imds = imageDatastore(digitDatasetPath, ...
        IncludeSubfolders=true, ...
        LabelSource="foldernames");
    [imdsTrain,imdsTestVal] = splitEachLabel(imds,0.7,"randomized");
    augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
end

To test the performance of the combined global model during and after training, create test and validation data sets containing images from all classes. Combine the test and validation data from each worker into a single datastore. Then, split this datastore into two datastores that each contain 15% of the overall data: one for validating the network during training and the other for testing the network after training.
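The split fractions can be checked with a little arithmetic. The sketch below assumes the 10,000 images are spread evenly across the 10 classes (an assumption for illustration), and all variable names are hypothetical.

```matlab
% Assumed: 10,000 images split evenly across 10 classes.
numImages = 10000;
numDigitClasses = 10;
imagesPerClass = numImages/numDigitClasses;     % images held by each worker

% Each worker keeps 70% for training and hands the rest to the client.
trainPerWorker = round(0.7*imagesPerClass);
testValPerWorker = imagesPerClass - trainPerWorker;

% The client pools the held-out images and splits them in half.
pooledTestVal = numDigitClasses*testValPerWorker;
numTestImages = pooledTestVal/2;
numValImages = pooledTestVal - numTestImages;

% Fraction of the full data set in each of the two client datastores.
fractionTest = numTestImages/numImages;
fractionVal = numValImages/numImages;
fprintf("train/worker: %d, test: %d (%.0f%%), val: %d (%.0f%%)\n", ...
    trainPerWorker,numTestImages,100*fractionTest,numValImages,100*fractionVal);
```

Under that assumption, half of the held-out 30% gives the 15% share for each of the validation and test sets.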

fileList = [];
labelList = [];

for i = 1:numWorkers
    tmp = imdsTestVal{i};
    fileList = cat(1,fileList,tmp.Files);
    labelList = cat(1,labelList,tmp.Labels);
end

imdsGlobalTestVal = imageDatastore(fileList);
imdsGlobalTestVal.Labels = labelList;

[imdsGlobalTest,imdsGlobalVal] = splitEachLabel(imdsGlobalTestVal,0.5,"randomized");

augimdsGlobalTest = augmentedImageDatastore(inputSize(1:2),imdsGlobalTest);
augimdsGlobalVal = augmentedImageDatastore(inputSize(1:2),imdsGlobalVal);

The data is now arranged such that each worker has data from a single class to train on, and the client holds validation and test data from all classes.
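Variables created inside an spmd block live on the workers and appear on the client as Composite objects, which is why the loop above reads imdsTestVal{i}. The following is a minimal sketch of that pattern, assuming Parallel Computing Toolbox is installed and a parallel pool is open. The names localValue, firstValue, and allValues are illustrative only.

```matlab
% Assumes Parallel Computing Toolbox and an open parallel pool.
spmd
    % Runs on every worker; spmdIndex is the index of the current worker.
    localValue = 10*spmdIndex;
end

% On the client, localValue is a Composite. Brace indexing fetches the
% value stored on a given worker.
firstValue = localValue{1};

% Gather the values from all workers into a numeric row vector.
allValues = [localValue{:}];
disp(allValues)
```

Only the small values you index are transferred to the client; anything you do not index stays on the workers.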

Define Network

Determine the number of classes in the data set.

classes = categories(imdsGlobalTest.Labels);
numClasses = numel(classes);

Define the network architecture.

layers = [

Create a dlnetwork object from the layers.

net = dlnetwork(layers)
net = 
  dlnetwork with properties:

         Layers: [9×1 nnet.cnn.layer.Layer]
    Connections: [8×2 table]
     Learnables: [6×3 table]
          State: [0×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.

Define Model Loss Function

Create the function modelLoss, listed in the Model Loss Function section of this example, that takes a dlnetwork object and a mini-batch of input data with corresponding labels and returns the loss and the gradients of the loss with respect to the learnable parameters in the network.
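For intuition about the loss value, the following sketch computes a cross-entropy loss by hand in base MATLAB for two observations and three classes. It assumes the loss is averaged over observations. The normalization that the crossentropy function applies depends on its options, so treat this as an illustration rather than a description of that function.

```matlab
% Predicted class probabilities: rows are classes, columns are observations.
Y = [0.7 0.2
     0.2 0.5
     0.1 0.3];

% One-hot targets in the same layout (observation 1 is class 1,
% observation 2 is class 2).
T = [1 0
     0 1
     0 0];

% Cross-entropy averaged over the observations (assumed normalization).
numObservations = size(Y,2);
loss = -sum(T.*log(Y),"all")/numObservations;

% Only the probability assigned to the true class contributes.
lossByHand = -(log(0.7) + log(0.5))/2;
fprintf("loss = %.4f, by hand = %.4f\n",loss,lossByHand);
```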

Define Federated Averaging Function

Create the function federatedAveraging, listed in the Federated Averaging Function section of this example, that takes the learnable parameters of the networks on each worker and the normalization factor for each worker, and returns the averaged learnable parameters across all the networks. Use the average learnable parameters to update the global network and the network on each worker.
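To make the weighted average concrete, here is a small base MATLAB sketch with two hypothetical workers. The parameter values and observation counts are made up for illustration, and the loop is a simplified stand-in for the idea, not the listing used in this example.

```matlab
% Toy "learnables" from two hypothetical workers: each is a cell array
% holding a weight matrix and a bias vector.
workerA = {ones(2,2); [1; 1]};
workerB = {3*ones(2,2); [5; 5]};
workerLearnables = {workerA, workerB};

% Assumed observation counts: worker A holds 100, worker B holds 300.
numObs = [100 300];
normalizationFactor = numObs/sum(numObs);   % [0.25 0.75]

% Weighted sum of each parameter array across the workers.
averaged = cell(size(workerA));
for p = 1:numel(averaged)
    averaged{p} = zeros(size(workerA{p}),"like",workerA{p});
    for k = 1:numel(workerLearnables)
        averaged{p} = averaged{p} + normalizationFactor(k).*workerLearnables{k}{p};
    end
end

disp(averaged{1})   % every entry is 0.25*1 + 0.75*3 = 2.5
disp(averaged{2})   % every entry is 0.25*1 + 0.75*5 = 4
```

Because the factors sum to 1, the result stays on the same scale as the individual parameters, and the worker with more data pulls the average toward its own values.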

Define Compute Accuracy Function

Create the function computeAccuracy, listed in the Compute Accuracy Function section of this example, that takes a dlnetwork object, a data set inside a minibatchqueue object, and the list of classes, and returns the accuracy of the predictions across all observations in the data set.
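Accuracy here is the fraction of observations whose predicted class matches the true class. The sketch below illustrates that calculation in base MATLAB on made-up scores, using max as a stand-in for onehotdecode. The variable names and values are illustrative only.

```matlab
% Made-up network scores: rows are classes, columns are observations.
scores = [0.7 0.2 0.1 0.3
          0.2 0.5 0.6 0.3
          0.1 0.3 0.3 0.4];

% True class index of each observation.
trueIdx = [1 2 3 3];

% The predicted class is the row with the largest score in each column.
[~,predIdx] = max(scores,[],1);

% Fraction of observations classified correctly.
correct = predIdx == trueIdx;
accuracy = mean(correct);
disp(accuracy)      % 3 of 4 correct, so 0.75
```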

Specify Training Options

During training, the workers periodically communicate their network learnable parameters to the client, so that the client can update the global model. The training is divided into rounds. At the end of each round of training, the learnable parameters are averaged and the global model is updated. The worker models are then replaced with the new global model, and training continues on the workers.

Train for 300 rounds, with 5 epochs per round. Training for a small number of epochs per round ensures that the networks on the workers do not diverge too far before they are averaged.

numRounds = 300;
numEpochsperRound = 5;
miniBatchSize = 100;

Specify the options for SGDM optimization. Specify a learn rate of 0.001 and a momentum of 0.
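For reference, a common way to write the SGDM update is shown below, with learn rate α, momentum γ, parameters θ, velocity v, and loss L. This notation is an assumption for illustration and is not taken from the example. With γ = 0, the value used here, the velocity carries no history and the update reduces to plain stochastic gradient descent.

```latex
\begin{aligned}
v_{t+1} &= \gamma\, v_t - \alpha\, \nabla_{\theta} L(\theta_t) \\
\theta_{t+1} &= \theta_t + v_{t+1}
\end{aligned}
```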

learnRate = 0.001;
momentum = 0;

Train Model

Create a function handle to the custom mini-batch preprocessing function preprocessMiniBatch (defined in the Mini-Batch Preprocessing Function section of this example).

On each worker, find the total number of training observations processed locally on that worker. Use this number to normalize the learnable parameters on each worker when you find the average learnable parameters after each communication round. This helps to balance the average if there is a difference between the amount of data on each worker.
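Written as a formula, the normalized average described here is a weighted sum. The symbols are introduced for illustration only: K is the number of workers, n_k is the number of training observations on worker k, and θ_k is a learnable parameter array on worker k.

```latex
\theta_{\text{global}} = \sum_{k=1}^{K} \frac{n_k}{n}\,\theta_k,
\qquad n = \sum_{k=1}^{K} n_k
```

The ratio n_k/n plays the role of the normalization factor for worker k. When every worker holds the same amount of data, each ratio equals 1/K and the formula reduces to a simple mean.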

On each worker, create a minibatchqueue object that processes and manages mini-batches of images during training. For each mini-batch:

  • Preprocess the data using the custom mini-batch preprocessing function preprocessMiniBatch to convert the labels to one-hot encoded variables.

  • Format the image data with the dimension labels 'SSCB' (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.

  • Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).

preProcess = @(x,y)preprocessMiniBatch(x,y,classes);

spmd
    sizeOfLocalDataset = augimdsTrain.NumObservations;
    mbq = minibatchqueue(augimdsTrain, ...
        MiniBatchSize=miniBatchSize, ...
        MiniBatchFcn=preProcess, ...
        MiniBatchFormat=["SSCB",""]);
end

Create a minibatchqueue object that manages the validation data to use during training. Use the same settings as the minibatchqueue on each worker.

mbqGlobalVal = minibatchqueue(augimdsGlobalVal, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=preProcess, ...
    MiniBatchFormat=["SSCB",""]);

Initialize the trainingProgressMonitor object. Because the timer starts when you create the monitor, make sure that you create the object close to the training loop.

monitor = trainingProgressMonitor( ...
    Metrics="GlobalAccuracy", ...
    Info="CommunicationRound", ...
    XLabel="Communication Round");

Initialize the velocity parameter for the SGDM solver.

velocity = [];

Initialize the global model. To start, the global model has the same initial parameters as the untrained network on each worker.

globalModel = net;

Train the model using a custom training loop. For each communication round:

  • Update the networks on the workers with the latest global network.

  • Train the networks on the workers for five epochs.

  • Find the average parameters of all the networks using the federatedAveraging function.

  • Replace the global network parameters with the average value.

  • Calculate the accuracy of the updated global network using the validation data.

  • Update the global accuracy in the training progress monitor.

  • Stop if the Stop property is true. The Stop property value of the TrainingProgressMonitor object changes to true when you click the Stop button.

For each epoch, shuffle the data and loop over mini-batches of data. For each mini-batch:

  • Evaluate the model loss and gradients using the dlfeval and modelLoss functions.

  • Update the local network parameters using the sgdmupdate function.
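The round structure described above can be tried out on a toy problem without any toolboxes. The following sketch fits a one-parameter linear model y = w*x with two simulated "workers" that see different input ranges. The data, learn rate, and variable names are assumptions chosen for illustration, and the workers run one after another in a plain loop rather than in parallel.

```matlab
% Toy federated averaging on a 1-D linear model y = w*x (base MATLAB only).
wTrue = 2;

% Two hypothetical workers with different input ranges (non-IID data)
% and different amounts of data.
x = {linspace(0,1,20), linspace(1,2,60)};
y = cellfun(@(xi) wTrue*xi, x, UniformOutput=false);

numToyRounds = 50;
epochsPerRound = 5;
toyLearnRate = 0.05;

% Normalization factors based on the amount of data on each worker.
numObs = cellfun(@numel,x);
normalizationFactor = numObs/sum(numObs);

wGlobal = 0;
for r = 1:numToyRounds
    wLocal = zeros(1,numel(x));
    for k = 1:numel(x)
        % Start each worker from the latest global parameter.
        w = wGlobal;
        for e = 1:epochsPerRound
            % Gradient of the mean squared error with respect to w.
            grad = mean(2*(w*x{k} - y{k}).*x{k});
            w = w - toyLearnRate*grad;
        end
        wLocal(k) = w;
    end
    % Replace the global parameter with the weighted average.
    wGlobal = sum(normalizationFactor.*wLocal);
end
disp(wGlobal)   % approaches wTrue
```

Neither simulated worker sees the full input range, yet the averaged parameter converges toward the value that fits all of the data.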

round = 0;
while round < numRounds && ~monitor.Stop

    round = round + 1;

    spmd
        % Send global updated parameters to each worker.
        net.Learnables.Value = globalModel.Learnables.Value;

        % Loop over epochs.
        for epoch = 1:numEpochsperRound
            % Shuffle data.
            shuffle(mbq);

            % Loop over mini-batches.
            while hasdata(mbq)

                % Read mini-batch of data.
                [X,T] = next(mbq);

                % Evaluate the model loss and gradients using dlfeval and the
                % modelLoss function.
                [loss,gradients] = dlfeval(@modelLoss,net,X,T);

                % Update the network parameters using the SGDM optimizer.
                [net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum);

            end
        end

        % Collect updated learnable parameters on each worker.
        workerLearnables = net.Learnables.Value;
    end

    % Find normalization factors for each worker based on ratio of data
    % processed on that worker.
    sizeOfAllDatasets = sum([sizeOfLocalDataset{:}]);
    normalizationFactor = [sizeOfLocalDataset{:}]/sizeOfAllDatasets;

    % Update the global model with new learnable parameters, normalized and
    % averaged across all workers.
    globalModel.Learnables.Value = federatedAveraging(workerLearnables,normalizationFactor);

    % Calculate the accuracy of the global model.
    accuracy = computeAccuracy(globalModel,mbqGlobalVal,classes);

    % Update the training progress monitor.
    recordMetrics(monitor,round,GlobalAccuracy=accuracy);
    updateInfo(monitor,CommunicationRound=round + " of " + numRounds);
    monitor.Progress = 100*round/numRounds;

end

After the final round of training, update the network on each worker with the final average learnable parameters. This is important if you want to continue to use or train the network on the workers.

spmd
    net.Learnables.Value = globalModel.Learnables.Value;
end

Test Model

Test the classification accuracy of the model by comparing the predictions on the test set with the true labels.

Create a minibatchqueue object that manages the test data. Use the same settings as the minibatchqueue objects used during training and validation.

mbqGlobalTest = minibatchqueue(augimdsGlobalTest, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=preProcess, ...
    MiniBatchFormat=["SSCB",""]);

Use the computeAccuracy function to compute the predicted classes and calculate the accuracy of the predictions across all the test data.

accuracy = computeAccuracy(globalModel,mbqGlobalTest,classes)
accuracy = single

After you are done with your computations, you can delete your parallel pool. The gcp function returns the current parallel pool object so you can delete the pool.

delete(gcp("nocreate"));

Model Loss Function

The modelLoss function takes a dlnetwork object net, a mini-batch of input data X with corresponding labels T and returns the loss and the gradients of the loss with respect to the learnable parameters in net. To compute the gradients automatically, use the dlgradient function. To compute predictions of the network during training, use the forward function.

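function [loss,gradients] = modelLoss(net,X,T)

    % Forward pass in training mode.
    YPred = forward(net,X);
    % Cross-entropy loss between predictions and one-hot targets.
    loss = crossentropy(YPred,T);
    % Gradients of the loss with respect to the learnable parameters.
    gradients = dlgradient(loss,net.Learnables);

end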


Compute Accuracy Function

The computeAccuracy function takes a dlnetwork object net, a minibatchqueue object mbq, and the list of classes, and returns the accuracy of all the predictions on the data set provided. To compute predictions of the network during validation or after training is finished, use the predict function.

function accuracy = computeAccuracy(net,mbq,classes)

    correctPredictions = [];
    % Restart the queue so that every call covers the full data set.
    shuffle(mbq);
    while hasdata(mbq)
        [XTest,TTest] = next(mbq);
        TTest = onehotdecode(TTest,classes,1)';
        YPred = predict(net,XTest);
        YPred = onehotdecode(YPred,classes,1)';
        correctPredictions = [correctPredictions; YPred == TTest];
    end
    predSum = sum(correctPredictions);
    accuracy = single(predSum./size(correctPredictions,1));

end


Mini-Batch Preprocessing Function

The preprocessMiniBatch function preprocesses the data using the following steps:

  1. Extract the image data from the incoming cell array and concatenate into a numeric array. Concatenating the image data over the fourth dimension adds a third dimension to each image, to be used as a singleton channel dimension.

  2. Extract the label data from the incoming cell arrays and concatenate into a categorical array along the second dimension.

  3. One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.
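The effect of these steps on array shapes can be sketched in base MATLAB. The example below uses random stand-in images and string labels, and it replaces onehotencode with a manual comparison so that it runs without any toolboxes. All names and values are illustrative only.

```matlab
% Three hypothetical 28-by-28 grayscale images, held in a cell array
% in the way a datastore would return them.
XCell = {rand(28,28), rand(28,28), rand(28,28)};

% Concatenating along dimension 4 leaves a singleton channel dimension.
X = cat(4,XCell{:});
disp(size(X))                       % 28 28 1 3

% Labels for the three images (1-by-3) and the list of class names (3-by-1).
labels = ["2" "0" "2"];
classNames = ["0"; "1"; "2"];

% Manual one-hot encoding along the first dimension using implicit
% expansion (a stand-in for onehotencode).
Y = double(classNames == labels);   % 3-by-3, one column per image
disp(Y)
```

Each column of Y contains a single 1 in the row of the matching class, which is the classes-by-batch layout that the network output uses.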

function [X,Y] = preprocessMiniBatch(XCell,YCell,classes)

    % Concatenate image data along the fourth (batch) dimension.
    X = cat(4,XCell{1:end});
    % Extract label data from cell and concatenate.
    Y = cat(2,YCell{1:end});
    % One-hot encode labels.
    Y = onehotencode(Y,1,ClassNames=classes);

end


Federated Averaging Function

The federatedAveraging function takes the learnable parameters of the networks on each worker and the normalization factor for each worker, and returns the averaged learnable parameters across all the networks. Use the average learnable parameters to update the global network and the network on each worker.

function learnables = federatedAveraging(workerLearnables,normalizationFactor)

    numWorkers = size(normalizationFactor,2);
    % Initialize container for averaged learnables with same size as existing
    % learnables. Use learnables of first worker network as an example.
    exampleLearnables = workerLearnables{1};
    learnables = cell(height(exampleLearnables),1);
    for i = 1:height(learnables)
        learnables{i} = zeros(size(exampleLearnables{i}),"like",(exampleLearnables{i}));
    end
    % Add the normalized learnable parameters of all workers to
    % calculate average values.
    for i = 1:numWorkers
        tmp = workerLearnables{i};
        for values = 1:numel(learnables)
            learnables{values} = learnables{values} + normalizationFactor(i).*tmp{values};
        end
    end

end


[1] McMahan, H. Brendan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Agüera y Arcas. "Communication-Efficient Learning of Deep Networks from Decentralized Data." Preprint, submitted. February, 2017.
