Run Custom Training Loops on a GPU and in Parallel
You can speed up your custom training loops by running on a GPU, in parallel using multiple GPUs, or on a cluster.
It is recommended to train using a GPU or multiple GPUs. Use a single CPU or multiple CPUs only if you do not have a GPU. CPUs are normally much slower than GPUs for both training and inference. Running on a single GPU typically offers much better performance than running on multiple CPU cores.
Note
This topic shows you how to perform custom training on GPUs, in parallel, and in the cloud. To learn about parallel and GPU workflows using the trainNetwork function, see Scale Up Deep Learning in Parallel, on GPUs, and in the Cloud.
Using a GPU or parallel options requires Parallel Computing Toolbox™. Using a GPU also requires a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Using a remote cluster also requires MATLAB® Parallel Server™.
Train Network on GPU
By default, custom training loops run on the CPU. Automatic differentiation using dlgradient and dlfeval supports running on the GPU when your data is on the GPU. To run a custom training loop on a GPU, convert your data to gpuArray (Parallel Computing Toolbox) during training.
You can use minibatchqueue to manage your data during training. minibatchqueue automatically prepares data for training, including custom preprocessing and converting data to dlarray and gpuArray. By default, minibatchqueue returns all mini-batch variables on the GPU if one is available. You can choose which variables to return on the GPU using the OutputEnvironment property.
For an example showing how to use minibatchqueue to train on the GPU, see Train Network Using Custom Training Loop.
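For instance, a minimal sketch of creating a minibatchqueue that returns formatted dlarray data on the GPU when one is available might look like the following. The datastore imdsTrain, the preprocessing function preprocessMiniBatch, and the image format "SSCB" are assumed placeholders, not part of the original example.
% Sketch: create a minibatchqueue from a datastore (imdsTrain and
% preprocessMiniBatch are hypothetical names; "SSCB" assumes image data).
mbq = minibatchqueue(imdsTrain, ...
    "MiniBatchSize",128, ...
    "MiniBatchFcn",@preprocessMiniBatch, ...
    "MiniBatchFormat","SSCB", ...
    "OutputEnvironment","auto");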
Alternatively, you can manually convert your data to gpuArray within the training loop.
To easily specify the execution environment, create the variable executionEnvironment that contains either "cpu", "gpu", or "auto".
executionEnvironment = "auto"
During training, after reading a mini-batch, check the execution environment option and convert the data to a gpuArray if necessary. The canUseGPU function checks for usable GPUs.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    X = gpuArray(X);
end
Train Single Network in Parallel
When you train in parallel, each worker trains the network simultaneously using a portion of a mini-batch. This means that you must combine the gradients, loss, and any state parameters after each iteration, according to the proportion of the mini-batch processed by each worker.
You can train in parallel on your local machine or on a remote cluster, for example, in the cloud. Start a parallel pool using the desired resources and partition your data between the workers. During training, combine the gradients, loss, and state after each iteration so that the learnable parameters on each worker update in synchronization. For an example showing how to perform custom training in parallel, see Train Network in Parallel with Custom Training Loop.
Set Up Parallel Environment
Set up the parallel environment that you want to use before training. Start a parallel pool using your desired resources. For training using multiple GPUs, start a parallel pool with as many workers as available GPUs. For best performance, MATLAB automatically assigns a different GPU to each worker.
If you are using your local machine, you can use canUseGPU and gpuDeviceCount (Parallel Computing Toolbox) to determine whether you have GPUs available. For example, to check for available GPUs and start a parallel pool with as many workers as available GPUs, use the following code:
if canUseGPU
    executionEnvironment = "gpu";
    numberOfGPUs = gpuDeviceCount("available");
    pool = parpool(numberOfGPUs);
else
    executionEnvironment = "cpu";
    pool = parpool;
end
If you are running using a remote cluster, for example, a cluster in the cloud, start a parallel pool with as many workers as the number of GPUs per machine multiplied by the number of machines.
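For example, the following sketch starts a pool on a cloud cluster profile. The profile name "MyClusterInTheCloud" and the machine and GPU counts are hypothetical values for illustration only.
% Sketch: start a pool on a remote cluster (profile name and counts are hypothetical).
numMachines = 2;
gpusPerMachine = 4;
pool = parpool("MyClusterInTheCloud",numMachines*gpusPerMachine);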
For more information on selecting specific GPUs, see Select Particular GPUs to Use for Training.
Specify Mini-Batch Size and Partition Data
Specify the mini-batch size that you want to use during training. For GPU training, a recommended practice is to scale up the mini-batch size linearly with the number of GPUs, in order to keep the workload on each GPU constant. For example, if you are training on a single GPU using a mini-batch size of 64, and you want to scale up to training with four GPUs of the same type, you can increase the mini-batch size to 256 so that each GPU processes 64 observations per iteration.
You can use the following code to scale up the mini-batch size by the number of workers, where N is the number of workers in your parallel pool.
if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* N
end
If you want to use a mini-batch size that is not exactly divisible by the number of workers in your parallel pool, then distribute the remainder across the workers.
workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)]
At the start of training, shuffle your data. Partition your data so that each worker has access to a portion of the mini-batch. To partition a datastore, use the partition function.
You can use minibatchqueue to manage the data on each worker during training. minibatchqueue automatically prepares data for training, including custom preprocessing and converting data to dlarray and gpuArray. Create a minibatchqueue on each worker using the partitioned datastore. Set the MiniBatchSize property using the mini-batch sizes calculated for each worker.
At the start of each training iteration, use the gop (Parallel Computing Toolbox) function to check that all worker minibatchqueue objects can return data. If any worker runs out of data, training stops. If your overall mini-batch size is not exactly divisible by the number of workers and you do not discard partial mini-batches, some workers might run out of data before others.
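If you prefer to discard partial mini-batches so that every worker processes only full mini-batches, one option is to set the PartialMiniBatch option when creating the minibatchqueue on each worker. This is a sketch using the variable names from the spmd example below, not part of the original example.
% Sketch: discard partial mini-batches on each worker.
workerMbq = minibatchqueue(workerImds, ...
    "MiniBatchSize",workerMiniBatchSize(spmdIndex), ...
    "MiniBatchFcn",@preprocessMiniBatch, ...
    "PartialMiniBatch","discard");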
Write your training code inside an spmd (Parallel Computing Toolbox) block so that the training loop executes on each worker.
spmd
    % Reset and shuffle the datastore.
    reset(augimdsTrain);
    augimdsTrain = shuffle(augimdsTrain);

    % Partition the datastore.
    workerImds = partition(augimdsTrain,N,spmdIndex);

    % Create a minibatchqueue on each worker using the partitioned datastore.
    workerMbq = minibatchqueue(workerImds,...
        "MiniBatchSize",workerMiniBatchSize(spmdIndex),...
        "MiniBatchFcn",@preprocessMiniBatch);

    ...

    for epoch = 1:numEpochs
        % Reset and shuffle the minibatchqueue on each worker.
        shuffle(workerMbq);

        % Loop over mini-batches.
        while gop(@and,hasdata(workerMbq))
            % Custom training loop
            ...
        end
        ...
    end
end
Aggregate Gradients
To ensure that the network on each worker learns from all data and not just the data on that worker, aggregate the gradients and use the aggregated gradients to update the network on each worker.
For example, suppose you are training the network net using the model loss function modelLoss. Your training loop contains the following code for evaluating the loss, gradients, and statistics on each worker:
[workerLoss,workerGradients,workerState] = dlfeval(@modelLoss,net,workerX,workerT);
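A minimal sketch of such a model loss function for a classification dlnetwork might look like the following. The loss and forward pass depend on your network and task; this is an assumed example, not the function from the original.
% Sketch of a model loss function (assumes a classification task).
function [loss,gradients,state] = modelLoss(net,X,T)
    [Y,state] = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end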
In the call to dlfeval above, workerX and workerT are the predictor and true response on each worker, respectively. To aggregate the gradients, use a weighted sum. Define a helper function to sum the gradients across the workers.
function gradients = aggregateGradients(gradients,factor)
    gradients = extractdata(gradients);
    gradients = gplus(factor*gradients);
end
Inside the training loop, use dlupdate to apply the function to the gradients of each learnable parameter.
workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});
Aggregate Loss and Accuracy
To find the network loss and accuracy, for example, to plot them during training to monitor training progress, aggregate the values of the loss and accuracy on all of the workers. Typically, the aggregated value is the sum of the value on each worker, weighted by the proportion of the mini-batch used on each worker. To aggregate the losses and accuracy each iteration, calculate the weight factor for each worker and use gplus (Parallel Computing Toolbox) to sum the values on each worker.
workerNormalizationFactor = workerMiniBatchSize(spmdIndex)./miniBatchSize;
loss = gplus(workerNormalizationFactor*extractdata(dlworkerLoss));
accuracy = gplus(workerNormalizationFactor*extractdata(dlworkerAccuracy));
Aggregate Statistics
If your network contains layers that track the statistics of your training data, such as batch normalization layers, then you must aggregate the statistics across all workers after each training iteration. Doing so ensures that the network learns statistics that are representative of the entire training set.
You can identify the layers that contain statistics information before training. For example, if you are using a dlnetwork with batch normalization layers, you can use the following code to find the relevant layers.
batchNormLayers = arrayfun(@(l)isa(l,'nnet.cnn.layer.BatchNormalizationLayer'),net.Layers);
batchNormLayersNames = string({net.Layers(batchNormLayers).Name});
state = net.State;
isBatchNormalizationStateMean = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedMean";
isBatchNormalizationStateVariance = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedVariance";
Define a helper function that aggregates the statistics using a weighted combination of the per-worker mean and variance:

$$\bar{x} = \sum_{j=1}^{N} \frac{m_j}{M}\,\bar{x}_j, \qquad s^2 = \sum_{j=1}^{N} \frac{m_j}{M}\left(s_j^2 + \left(\bar{x}_j - \bar{x}\right)^2\right),$$

where $N$ is the total number of workers, $M$ is the total number of observations in a mini-batch, $m_j$ is the number of observations processed on the $j$th worker, $\bar{x}_j$ and $s_j^2$ are the mean and variance statistics calculated on that worker, and $\bar{x}$ is the aggregated mean across all workers.
function state = aggregateState(state,factor,...
    isBatchNormalizationStateMean,isBatchNormalizationStateVariance)

    stateMeans = state.Value(isBatchNormalizationStateMean);
    stateVariances = state.Value(isBatchNormalizationStateVariance);

    for j = 1:numel(stateMeans)
        meanVal = stateMeans{j};
        varVal = stateVariances{j};

        % Calculate combined mean.
        combinedMean = gplus(factor*meanVal);

        % Calculate combined variance terms to sum.
        varTerm = factor.*(varVal + (meanVal - combinedMean).^2);

        % Update state.
        stateMeans{j} = combinedMean;
        stateVariances{j} = gplus(varTerm);
    end

    state.Value(isBatchNormalizationStateMean) = stateMeans;
    state.Value(isBatchNormalizationStateVariance) = stateVariances;
end
Inside the training loop, use the helper function to update the state of the batch normalization layers with the combined mean and variance.
net.State = aggregateState(workerState,workerNormalizationFactor,...
isBatchNormalizationStateMean,isBatchNormalizationStateVariance);
Plot Results During Training
If you want to plot results during training, you can send data from the workers to the client using a DataQueue object.
To easily specify that the plot should be on or off, create the variable plots that contains either "training-progress" or "none".
plots = "training-progress";
Before training, initialize the DataQueue and the animated line using the animatedline function.
if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end
Create the DataQueue object. Use afterEach to call the helper function displayTrainingProgress each time data is sent from the worker to the client.
Q = parallel.pool.DataQueue;
displayFcn = @(x) displayTrainingProgress(x,lineLossTrain);
afterEach(Q,displayFcn);
The displayTrainingProgress helper function contains the code used to add points to the animated line and display the training epoch and duration.

function displayTrainingProgress(data,line)
    addpoints(line,double(data(3)),double(data(2)))
    D = duration(0,0,data(4),'Format','hh:mm:ss');
    title("Epoch: " + data(1) + ", Elapsed: " + string(D))
    drawnow
end
Inside the training loop, at the end of each epoch, use the DataQueue to send the training data from the workers to the client. At the end of each iteration, the aggregated loss is the same on each worker, so you can send data from a single worker.
% Display training progress information.
if spmdIndex == 1
    data = [epoch loss iteration toc(start)];
    send(Q,gather(data));
end
Train Multiple Networks in Parallel
To train multiple networks in parallel, start a parallel pool using your desired resources and use parfor (Parallel Computing Toolbox) to train a single network on each worker.
You can run locally or using a remote cluster. Using a remote cluster requires MATLAB Parallel Server. For more information about managing cluster resources, see Discover Clusters and Use Cluster Profiles (Parallel Computing Toolbox). If you have multiple GPUs and want to exclude some from training, you can choose the GPUs you use to train on. For more information on selecting specific GPUs, see Select Particular GPUs to Use for Training.
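For illustration, one way to assign particular GPUs to pool workers is to select a device on each worker inside an spmd block. This is a sketch; the GPU indices shown are hypothetical.
% Sketch: use GPUs 1 and 3 only (hypothetical indices).
gpuIndices = [1 3];
parpool(numel(gpuIndices));
spmd
    gpuDevice(gpuIndices(spmdIndex));
end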
You can modify the network or training parameters on each worker to perform parameter sweeps in parallel. For example, if networks is an array of dlnetwork objects, you can use code of the following form to train multiple different networks using the same data.
parpool ("Processes",numNetworks); parfor idx = 1:numNetworks iteration = 0; velocity = []; % Allocate one network per worker net = networks(idx) % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; % Custom training loop ... end end % Send the trained networks back to the client. trainedNetworks{idx} = net; end
After the parfor loop finishes, trainedNetworks contains the resulting networks trained by the workers.
Plot Results During Training
To monitor training progress on the workers, you can use a DataQueue to send data back from the workers.
To easily specify that the plot should be on or off, create the variable plots that contains either "training-progress" or "none".
plots = "training-progress";
Before training, initialize the DataQueue and the animated lines using the animatedline function. Create a subplot for each network you are training.
if plots == "training-progress"
    f = figure;
    f.Visible = true;
    for i=1:numNetworks
        subplot(numNetworks,1,i)
        xlabel('Iteration');
        ylabel('loss');
        lines(i) = animatedline;
    end
end
Create the DataQueue object. Use afterEach to call the helper function displayTrainingProgress each time data is sent from the worker to the client.
Q = parallel.pool.DataQueue;
displayFcn = @(x) displayTrainingProgress(x,lines);
afterEach(Q,displayFcn);
The displayTrainingProgress helper function contains the code used to add points to the animated lines. The first element of the sent data is the parfor loop index, which selects the line to update.

function displayTrainingProgress(data,lines)
    addpoints(lines(data(1)),double(data(4)),double(data(3)))
    D = duration(0,0,data(5),'Format','hh:mm:ss');
    title("Epoch: " + data(2) + ", Elapsed: " + string(D))
    drawnow limitrate nocallbacks
end
Inside the training loop, at the end of each iteration, use the DataQueue to send the training data from the workers to the client. Send the parfor loop index as well as the training information so that the points are added to the correct line for each worker.
% Display training progress information.
data = [idx epoch loss iteration toc(start)];
send(Q,gather(data));
Use Experiment Manager to Train in Parallel
You can use Experiment Manager to run your custom training loops in parallel. You can either run multiple trials at the same time, or run a single trial at a time using parallel resources.
To run multiple trials at the same time using one parallel worker for each trial, set up your custom training experiment and enable the Use Parallel option before running your experiment.
To run a single trial at a time using multiple parallel workers, define your parallel environment in your experiment training function and use an spmd (Parallel Computing Toolbox) block to train the network in parallel. For more information on training a single network in parallel with a custom training loop, see Train Single Network in Parallel.
For more information on training in parallel using Experiment Manager, see Use Experiment Manager to Train in Parallel.
See Also
parfor (Parallel Computing Toolbox) | parfeval (Parallel Computing Toolbox) | gpuArray (Parallel Computing Toolbox) | dlarray | dlnetwork