Classify ECG Signals Using Long Short-Term Memory Networks
This example shows how to classify heartbeat electrocardiogram (ECG) data from the PhysioNet 2017 Challenge using deep learning and signal processing. In particular, the example uses Long Short-Term Memory networks and time-frequency analysis.
For an example that reproduces and accelerates this workflow using a GPU and Parallel Computing Toolbox™, see Classify ECG Signals Using Long Short-Term Memory Networks with GPU Acceleration (Signal Processing Toolbox).
Introduction
ECGs record the electrical activity of a person's heart over a period of time. Physicians use ECGs to detect visually if a patient's heartbeat is normal or irregular.
Atrial fibrillation (AFib) is a type of irregular heartbeat that occurs when the heart's upper chambers, the atria, beat out of coordination with the lower chambers, the ventricles.
This example uses ECG data from the PhysioNet 2017 Challenge [1], [2], [3], which is available at https://physionet.org/challenge/2017/. The data consists of a set of ECG signals sampled at 300 Hz and divided by a group of experts into four different classes: Normal (N), AFib (A), Other Rhythm (O), and Noisy Recording (~). This example shows how to automate the classification process using deep learning. The procedure explores a binary classifier that can differentiate Normal ECG signals from signals showing signs of AFib.
A long short-term memory (LSTM) network is a type of recurrent neural network (RNN) well-suited to study sequence and time-series data. An LSTM network can learn long-term dependencies between time steps of a sequence. The LSTM layer (lstmLayer) can look at the time sequence in the forward direction, while the bidirectional LSTM layer (bilstmLayer) can look at the time sequence in both forward and backward directions. This example uses a bidirectional LSTM layer.
This example shows the advantages of using a data-centric approach when solving artificial intelligence (AI) problems. An initial attempt to train the LSTM network using raw data gives substandard results. Training the same model architecture using extracted features leads to a considerable improvement in classification performance.
To accelerate the training process, run this example on a machine with a GPU. If your machine has a GPU and Parallel Computing Toolbox™, then MATLAB® automatically uses the GPU for training; otherwise, it uses the CPU.
Load and Examine Data
Run the ReadPhysionetData script to download the data from the PhysioNet website and generate a MAT-file (PhysionetData.mat) that contains the ECG signals in the appropriate format. Downloading the data might take a few minutes. Use a conditional statement that runs the script only if PhysionetData.mat does not already exist in the current folder.
if ~isfile('PhysionetData.mat')
    ReadPhysionetData
end
load PhysionetData
The loading operation adds two variables to the workspace: Signals and Labels. Signals is a cell array that holds the ECG signals. Labels is a categorical array that holds the corresponding ground-truth labels of the signals.
Signals(1:5)'
ans=1×5 cell array
{1×9000 double} {1×9000 double} {1×18000 double} {1×9000 double} {1×18000 double}
Labels(1:5)'
ans = 1×5 categorical
N N N A A
Use the summary function to see how many AFib signals and Normal signals are contained in the data.
summary(Labels)
Labels: 5788×1 categorical
     A               738
     N              5050
     <undefined>       0
Generate a histogram of signal lengths. Most of the signals are 9000 samples long.
L = cellfun(@length,Signals);
h = histogram(L);
xticks(0:3000:18000);
xticklabels(0:3000:18000);
title('Signal Lengths')
xlabel('Length')
ylabel('Count')
Visualize a segment of one signal from each class. AFib heartbeats are spaced out at irregular intervals while Normal heartbeats occur regularly. AFib heartbeat signals also often lack a P wave, which pulses before the QRS complex in a Normal heartbeat signal. The plot of the Normal signal shows a P wave and a QRS complex.
normal = Signals{1};
aFib = Signals{4};

tiledlayout("flow")
nexttile
plot(normal)
title("Normal Rhythm")
xlim([4000,5200])
ylabel("Amplitude (mV)")
text(4330,150,"P",HorizontalAlignment="center")
text(4370,850,"QRS",HorizontalAlignment="center")

nexttile
plot(aFib)
title("Atrial Fibrillation")
xlim([4000,5200])
xlabel("Samples")
ylabel("Amplitude (mV)")
Prepare Data for Training
During training, the trainnet function splits the data into mini-batches. The function then pads or truncates signals in the same mini-batch so they all have the same length. Too much padding or truncating can have a negative effect on the performance of the network, because the network might interpret a signal incorrectly based on the added or removed information.
To avoid excessive padding or truncating, apply the helperSegmentSignals function to the ECG signals so they are all 9000 samples long. The helperSegmentSignals function is defined at the end of this example. The function ignores signals with fewer than 9000 samples. If a signal has more than 9000 samples, helperSegmentSignals breaks it into as many 9000-sample segments as possible and ignores the remaining samples. For example, a signal with 18500 samples becomes two 9000-sample signals, and the remaining 500 samples are ignored.
[Signals,Labels] = helperSegmentSignals(Signals,Labels);
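The segmentation rule described above can be illustrated with a minimal sketch in base MATLAB. This is not the helper's implementation (that is listed in the appendix); the synthetic signal x and the names numSegs, segs, segCell, and numDropped are assumptions introduced only for illustration.

```matlab
% Illustrative sketch: split one synthetic 18500-sample signal into
% 9000-sample segments and drop the leftover samples.
rng("default")
x = randn(1,18500);                            % synthetic stand-in for an ECG signal
targetLength = 9000;

numSegs = floor(numel(x)/targetLength);        % number of complete segments
used = x(1:numSegs*targetLength);              % discard the trailing remainder
segs = reshape(used,targetLength,numSegs).';   % one segment per row
segCell = mat2cell(segs,ones(numSegs,1));      % cell array of 1-by-9000 rows

numDropped = numel(x) - numSegs*targetLength;  % samples that are ignored
```

For this input, the sketch produces two segments and ignores the last 500 samples, matching the 18500-sample case described in the text.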
View the first five elements of the Signals array to verify that each entry is now 9000 samples long.
Signals(1:5)'
ans=1×5 cell array
{1×9000 double} {1×9000 double} {1×9000 double} {1×9000 double} {1×9000 double}
First Attempt: Train Classifier Using Raw Signal Data
To design the classifier, use the raw signals generated in the previous section. Split the signals into a training set to train the classifier and a testing set to test the accuracy of the classifier on new data.
Use the summary function to show that the ratio of AFib signals to Normal signals is approximately 1:7.
summary(Labels)
Labels: 5655×1 categorical
     A               718
     N              4937
     <undefined>       0
Because about 7/8 of the signals are Normal, the classifier would learn that it can achieve a high accuracy simply by classifying all signals as Normal. To avoid this bias, augment the AFib data by duplicating AFib signals in the dataset so that there is the same number of Normal and AFib signals. This duplication, commonly called oversampling, is one form of data augmentation used in deep learning.
Split the signals according to their class.
afibX = Signals(Labels=="A"); afibY = Labels(Labels=="A"); normalX = Signals(Labels=="N"); normalY = Labels(Labels=="N");
Next, use splitlabels to divide targets from each class randomly into training, validation, and testing sets.
rng("default")

indA = splitlabels(afibY,[0.8 0.1],"randomized");
indN = splitlabels(normalY,[0.8 0.1],"randomized");

XTrainA = afibX(indA{1});
YTrainA = afibY(indA{1});

XTrainN = normalX(indN{1});
YTrainN = normalY(indN{1});

XValidA = afibX(indA{2});
YValidA = afibY(indA{2});

XValidN = normalX(indN{2});
YValidN = normalY(indN{2});

XTestA = afibX(indA{3});
YTestA = afibY(indA{3});

XTestN = normalX(indN{3});
YTestN = normalY(indN{3});
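As a rough illustration of what an 80/10/10 randomized split does within one class, the following sketch uses randperm from base MATLAB instead of splitlabels. It is a substitute technique shown under stated assumptions, not a description of how splitlabels works internally; the observation count n and the index variable names are hypothetical.

```matlab
% Illustrative sketch: randomized 80/10/10 split of n observations
% using randperm (a stand-in for splitlabels, applied to a single class).
rng("default")
n = 20;                                  % hypothetical number of observations
shuffled = randperm(n);                  % random ordering of 1..n

numTrain = round(0.8*n);                 % 80% for training
numValid = round(0.1*n);                 % 10% for validation

trainIdx = shuffled(1:numTrain);
validIdx = shuffled(numTrain+1:numTrain+numValid);
testIdx  = shuffled(numTrain+numValid+1:end);   % remainder for testing
```

Applying the split to each class separately, as the example does, keeps the class proportions roughly equal across the three sets.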
The dataset is imbalanced. To achieve a similar number of AFib and Normal signals, repeat the AFib signals seven times.
By default, the neural network randomly shuffles the data before training, ensuring that contiguous signals do not all have the same label.
XTrain = [repmat(XTrainA,7,1); XTrainN];
YTrain = [repmat(YTrainA,7,1); YTrainN];

XValid = [repmat(XValidA,7,1); XValidN];
YValid = [repmat(YValidA,7,1); YValidN];

XTest = [repmat(XTestA,7,1); XTestN];
YTest = [repmat(YTestA,7,1); YTestN];
The distribution between Normal and AFib signals is now essentially balanced in the training, validation, and testing sets.
summary(YTrain)
YTrain: 7968×1 categorical
     A              4018
     N              3950
     <undefined>       0
summary(YValid)
YValid: 997×1 categorical
     A               504
     N               493
     <undefined>       0
summary(YTest)
YTest: 998×1 categorical
     A               504
     N               494
     <undefined>       0
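The choice of seven repetitions follows from the class counts reported earlier. The sketch below reproduces that arithmetic and then shows the effect of repmat on a small set of hypothetical labels; the toy variables (toyA, toyN, toyBalanced) are assumptions for illustration and are not part of the example's data.

```matlab
% Illustrative sketch: choose a repetition factor for the minority class.
numA = 718;                          % AFib segments reported by summary(Labels)
numN = 4937;                         % Normal segments reported by summary(Labels)
repFactor = round(numN/numA);        % 4937/718 is about 6.9

% Hypothetical toy labels showing how repmat balances the classes.
toyA = categorical(repmat("A",3,1),["A" "N"]);
toyN = categorical(repmat("N",20,1),["A" "N"]);
toyFactor = round(numel(toyN)/numel(toyA));
toyBalanced = [repmat(toyA,toyFactor,1); toyN];
counts = countcats(toyBalanced);     % counts for categories A and N, in order
```

Because the repetition factor is an integer, the resulting classes are close to balanced rather than exactly equal.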
Define LSTM Network Architecture
LSTM networks can learn long-term dependencies between time steps of sequence data. This example uses the bidirectional LSTM layer bilstmLayer, as it looks at the sequence in both forward and backward directions.
Because the input signals have one dimension each, specify the input size to be sequences of size 1. Specify a bidirectional LSTM layer with 50 hidden units and output the last element of the sequence. This command instructs the bidirectional LSTM layer to map the input time series into 100 features (50 from each direction) and then prepares the output for the fully connected layer. Finally, specify two classes by including a fully connected layer of size 2, followed by a softmax layer.
layers = [ ...
    sequenceInputLayer(1)
    bilstmLayer(50,OutputMode="last")
    fullyConnectedLayer(2)
    softmaxLayer]
layers =
  4×1 Layer array with layers:

     1   ''   Sequence Input    Sequence input with 1 dimensions
     2   ''   BiLSTM            BiLSTM with 50 hidden units
     3   ''   Fully Connected   2 fully connected layer
     4   ''   Softmax           softmax
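To get a feel for the size of this network, the following sketch counts its learnable parameters by hand. It assumes the standard LSTM parameterization (four gates per direction, each with input weights, recurrent weights, and a bias) and that the BiLSTM output concatenates both directions; the variable names are introduced here for illustration only.

```matlab
% Illustrative sketch: count learnable parameters for the layer array above.
inputSize = 1;        % channels per time step
numHidden = 50;       % hidden units per direction
numClasses = 2;

% Each LSTM direction has 4 gates, each with input weights, recurrent
% weights, and a bias term.
paramsPerDirection = 4*numHidden*(inputSize + numHidden + 1);
bilstmParams = 2*paramsPerDirection;          % forward + backward

% The BiLSTM output concatenates both directions.
bilstmOutputSize = 2*numHidden;
fcParams = bilstmOutputSize*numClasses + numClasses;   % weights + biases

totalParams = bilstmParams + fcParams;
```

Under these assumptions the network has about 21,000 learnable parameters, nearly all of them in the BiLSTM layer.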
Next specify the training options for the classifier. Choosing among the options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.
Train using the Adam optimizer. Adam performs better than the default stochastic gradient descent with momentum (SGDM) solver with RNNs like LSTMs.
Train for 50 epochs using a mini-batch size of 200.
Set a gradient threshold of 1 to stabilize training.
Shuffle the data every epoch.
Specify the validation data.
Because the training data has sequences with rows and columns corresponding to channels and time steps, respectively, specify the input data format "CTB" (channel, time, batch).
Display the training progress in a plot and monitor the accuracy of the network during training.
Disable the verbose output.
options = trainingOptions("adam", ...
    MaxEpochs=50, ...
    MiniBatchSize=200, ...
    GradientThreshold=1, ...
    Shuffle="every-epoch", ...
    InitialLearnRate=1e-3, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    InputDataFormats="CTB", ...
    ValidationData={XValid,YValid}, ...
    OutputNetwork="best-validation", ...
    Verbose=false);
Train LSTM Network
Train the LSTM network with the specified training options and layer architecture by using trainnet. For classification, use cross-entropy loss. By default, the trainnet function uses a GPU if one is available. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainnet function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option. Because the training set is large, the training process can take several minutes. By default, if you specify validation data, then the trainnet function returns the network with the lowest validation loss.
net = trainnet(XTrain,YTrain,layers,"crossentropy",options);
The top subplot of the training-progress plot represents the training accuracy, which is the classification accuracy on each mini-batch. When training progresses successfully, this value typically increases towards 100%. The bottom subplot displays the training loss, which is the cross-entropy loss on each mini-batch. When training progresses successfully, this value typically decreases towards zero.
If the training is not converging, the plots might oscillate between values without trending in a certain upward or downward direction. This oscillation means that the training accuracy is not improving and the training loss is not decreasing. This situation can occur from the start of training, or the plots might plateau after some preliminary improvement in training accuracy. In many cases, changing the training options can help the network achieve convergence. Decreasing MiniBatchSize or decreasing InitialLearnRate might result in a longer training time, but it can help the network learn better.
Here, the training accuracy is very high, but the validation accuracy has not improved correspondingly. This likely indicates overfitting, meaning the model cannot generalize and fits too closely to the training dataset instead. There could be many reasons for this, such as the training data containing a lot of redundant and irrelevant information, and the network not learning the real key factors for classification.
Visualize Training and Testing Accuracy
Calculate the training accuracy, which represents the accuracy of the classifier on the signals on which it was trained. First, classify the training data.
To make predictions with multiple observations, use the minibatchpredict function. To convert the prediction scores to labels, use the scores2label function. The minibatchpredict function automatically uses a GPU if one is available. Otherwise, the function uses the CPU.
classNames = categories(YTrain);
scores = minibatchpredict(net,XTrain,InputDataFormats="CTB");
trainPred = scores2label(scores,classNames);
In classification problems, confusion matrices are used to visualize the performance of a classifier on a set of data for which the true values are known. The Target Class is the ground-truth label of the signal, and the Output Class is the label assigned to the signal by the network. The axes labels represent the class labels, AFib (A) and Normal (N).
Use the confusionchart command to visualize the classification performance for the training data predictions. Specify RowSummary as "row-normalized" to display the true positive rates and false negative rates in the row summary. Also, specify ColumnSummary as "column-normalized" to display the positive predictive values and false discovery rates in the column summary.
LSTMAccuracy = sum(trainPred == YTrain)/numel(YTrain)*100
LSTMAccuracy = 58.8730
figure
confusionchart(YTrain,trainPred,ColumnSummary="column-normalized",...
    RowSummary="row-normalized",Title="Confusion Chart for LSTM");
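The quantities shown in the confusion chart summaries can be reproduced by hand. The sketch below builds a 2-by-2 confusion matrix from hypothetical labels (not results from this example) and derives the accuracy, the row-normalized values, and the column-normalized values; all variable names are assumptions for illustration.

```matlab
% Illustrative sketch with hypothetical labels (not results from this example).
classNamesToy = ["A" "N"];
trueToy = categorical(["A";"A";"A";"N";"N";"N";"N";"N"],classNamesToy);
predToy = categorical(["A";"A";"N";"N";"N";"N";"A";"N"],classNamesToy);

% Rows index the true class and columns index the predicted class.
C = accumarray([double(trueToy) double(predToy)],1,[2 2]);

accuracyToy = 100*trace(C)/sum(C,"all");   % percentage classified correctly
rowNormalized = C./sum(C,2);       % rows sum to 1; diagonal holds true positive rates
columnNormalized = C./sum(C,1);    % columns sum to 1; diagonal holds positive predictive values
```

The off-diagonal entries of rowNormalized are the false negative rates for each true class, and the off-diagonal entries of columnNormalized are the false discovery rates for each predicted class.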
Now classify the testing data with the same network.
scores = minibatchpredict(net,XTest,InputDataFormats="CTB");
testPred = scores2label(scores,classNames);
Calculate the testing accuracy and visualize the classification performance as a confusion matrix.
LSTMAccuracy = sum(testPred == YTest)/numel(YTest)*100
LSTMAccuracy = 51.0020
figure
confusionchart(YTest,testPred,ColumnSummary="column-normalized",...
    RowSummary="row-normalized",Title="Confusion Chart for LSTM");
Second Attempt: Improve Performance with Feature Extraction
Feature extraction from the data can help improve the performance of the classifier. To decide which features to extract, this example adapts an approach that computes time-frequency images, such as spectrograms, and uses them to train convolutional neural networks (CNNs) [4], [5].
Visualize the spectrogram of each type of signal.
fs = 300;

figure
tiledlayout("flow")
nexttile
pspectrum(normal,fs,"spectrogram",TimeResolution=0.5)
title("Normal Signal")

nexttile
pspectrum(aFib,fs,"spectrogram",TimeResolution=0.5)
title("AFib Signal")
Because this example uses an LSTM instead of a CNN, it is important to translate the approach so it applies to one-dimensional signals. Time-frequency (TF) moments extract information from the spectrograms. Each moment can be used as a one-dimensional feature to input to the LSTM.
Explore two TF moments in the time domain:
Instantaneous frequency (instfreq)
Spectral entropy (pentropy)
The instfreq function estimates the time-dependent frequency of a signal as the first moment of the power spectrogram. The function computes a spectrogram using short-time Fourier transforms over time windows. In this example, the function uses 255 time windows. The time outputs of the function correspond to the centers of the time windows.
Visualize the instantaneous frequency for each type of signal.
[instFreqA,tA] = instfreq(aFib,fs);
[instFreqN,tN] = instfreq(normal,fs);

figure
tiledlayout("flow")
nexttile
plot(tN,instFreqN)
title("Normal Signal")
xlabel("Time (s)")
ylabel("Instantaneous Frequency")

nexttile
plot(tA,instFreqA)
title("AFib Signal")
xlabel("Time (s)")
ylabel("Instantaneous Frequency")
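The idea of a "first moment of the power spectrogram" can be made concrete with a small sketch: in each time window, the estimate is the power-weighted mean frequency. The spectrogram P and frequency grid f below are hypothetical values chosen for illustration, and the code is not the instfreq implementation.

```matlab
% Illustrative sketch: first spectral moment of a hypothetical power
% spectrogram P (rows are frequencies, columns are time windows).
f = (0:10:150).';                 % assumed frequency grid in Hz
P = zeros(numel(f),3);
P(f == 50,1) = 1;                 % window 1: all power at 50 Hz
P(f == 100,2) = 1;                % window 2: all power at 100 Hz
P(f == 50,3) = 0.5;               % window 3: power split evenly
P(f == 100,3) = 0.5;              %           between 50 Hz and 100 Hz

% Power-weighted mean frequency in each time window.
instFreqToy = sum(P.*f,1)./sum(P,1);
```

The third window shows why the estimate is a summary rather than a peak location: when power is split between two frequencies, the first moment falls between them.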
Use cellfun to apply the instfreq function to every cell in the training, testing, and validation sets.
instfreqTrain = cellfun(@(x)instfreq(x,fs)',XTrain,UniformOutput=false);
instfreqTest = cellfun(@(x)instfreq(x,fs)',XTest,UniformOutput=false);
instfreqValid = cellfun(@(x)instfreq(x,fs)',XValid,UniformOutput=false);
The spectral entropy measures how spiky or flat the spectrum of a signal is. A signal with a spiky spectrum, like a sum of sinusoids, has low spectral entropy. A signal with a flat spectrum, like white noise, has high spectral entropy. The spectralEntropy function estimates the spectral entropy based on a power spectrogram. As with the instantaneous frequency estimation, the spectrogram here is computed over 255 time windows.
[s,f,tA2] = pspectrum(aFib,fs,"spectrogram");
pentropyA = spectralEntropy(s,f,Scaled=true);
[s,f,tN2] = pspectrum(normal,fs,"spectrogram");
pentropyN = spectralEntropy(s,f,Scaled=true);
Visualize the spectral entropy for each type of signal.
figure
tiledlayout("flow")
nexttile
plot(tN2,pentropyN)
title("Normal Signal")
ylabel("Spectral Entropy")

nexttile
plot(tA2,pentropyA)
title("AFib Signal")
xlabel("Time (s)")
ylabel("Spectral Entropy")
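The contrast between spiky and flat spectra can be checked numerically. The sketch below assumes the usual scaled Shannon definition, in which the spectrum is normalized to a probability distribution and the entropy is divided by log2 of the number of bins. It is an illustration under that assumption, not the spectralEntropy source code, and the helper names normalizeBins and scaledEntropy are hypothetical.

```matlab
% Illustrative sketch: scaled Shannon entropy of a power spectrum
% (assumed definition; not the spectralEntropy implementation).
N = 16;                                   % number of frequency bins
flatSpectrum = ones(N,1);                 % white-noise-like spectrum
spikySpectrum = zeros(N,1);
spikySpectrum(4) = 1;                     % single-sinusoid-like spectrum

% Normalize nonzero bins to probabilities; zero bins contribute nothing
% and are dropped to avoid 0*log2(0).
normalizeBins = @(P) P(P > 0)./sum(P);
scaledEntropy = @(P) -sum(normalizeBins(P).*log2(normalizeBins(P)))/log2(numel(P));

entropyFlat = scaledEntropy(flatSpectrum);    % largest possible value
entropySpiky = scaledEntropy(spikySpectrum);  % smallest possible value
```

With this scaling the entropy lies between 0 (all power in one bin) and 1 (power spread evenly over all bins).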
Use cellfun to apply the helperComputeSpecEntropy function to every cell in the training, testing, and validation sets.
pentropyTrain = cellfun(@(x)helperComputeSpecEntropy(x,fs)',XTrain,UniformOutput=false);
pentropyTest = cellfun(@(x)helperComputeSpecEntropy(x,fs)',XTest,UniformOutput=false);
pentropyValid = cellfun(@(x)helperComputeSpecEntropy(x,fs)',XValid,UniformOutput=false);
Concatenate the features such that each cell in the new training, testing, and validation sets has two dimensions, or two features.
XTrain2 = cellfun(@(x,y)[x;y],instfreqTrain,pentropyTrain,UniformOutput=false);
XTest2 = cellfun(@(x,y)[x;y],instfreqTest,pentropyTest,UniformOutput=false);
XValid2 = cellfun(@(x,y)[x;y],instfreqValid,pentropyValid,UniformOutput=false);
Visualize the format of the new inputs. Each cell no longer contains one 9000-sample-long signal; now it contains two 255-sample-long features.
XTrain2(1:5)'
ans=1×5 cell array
{2×255 double} {2×255 double} {2×255 double} {2×255 double} {2×255 double}
Standardize Data
The instantaneous frequency and the spectral entropy have means that differ by almost one order of magnitude. Furthermore, the instantaneous frequency mean might be too high for the LSTM to learn effectively. When a network is fit on data with a large mean and a large range of values, large inputs could slow down the learning and convergence of the network [6].
mean(instFreqN)
ans = 5.5551
mean(pentropyN)
ans = 0.6324
Use the training set mean and standard deviation to standardize the training, testing and validation sets. Standardization, or z-scoring, is a popular way to improve network performance during training.
XV = [XTrain2{:}];
mu = mean(XV,2);
sg = std(XV,[],2);

XTrainSD = XTrain2;
XTrainSD = cellfun(@(x)(x-mu)./sg,XTrainSD,UniformOutput=false);

XValidSD = XValid2;
XValidSD = cellfun(@(x)(x-mu)./sg,XValidSD,UniformOutput=false);

XTestSD = XTest2;
XTestSD = cellfun(@(x)(x-mu)./sg,XTestSD,UniformOutput=false);
Show the means of the standardized instantaneous frequency and spectral entropy.
instFreqNSD = XTrainSD{1}(1,:);
pentropyNSD = XTrainSD{1}(2,:);

mean(instFreqNSD)
ans = 0.1545
mean(pentropyNSD)
ans = 0.1936
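The means printed above are close to zero but not exactly zero because the statistics are pooled over the whole training set, not computed per signal. The following sketch repeats the same standardization on a small hypothetical data set to show this; the toy cell arrays and variable names are assumptions for illustration only.

```matlab
% Illustrative sketch: z-score toy two-feature sequences using statistics
% pooled over the training set only.
XTrainToy = {[1 2 3; 10 20 30],[4 5; 40 50]};   % hypothetical training sequences
XTestToy  = {[2 3; 20 30]};                     % hypothetical test sequence

XVToy = [XTrainToy{:}];            % concatenate along time: 2-by-5
muToy = mean(XVToy,2);             % per-feature mean
sgToy = std(XVToy,[],2);           % per-feature standard deviation

zscoreToy = @(x)(x - muToy)./sgToy;
XTrainToySD = cellfun(zscoreToy,XTrainToy,UniformOutput=false);
XTestToySD  = cellfun(zscoreToy,XTestToy,UniformOutput=false);   % reuse training statistics

pooledSD = [XTrainToySD{:}];
pooledMean = mean(pooledSD,2);                  % approximately 0 for each feature
pooledStd = std(pooledSD,[],2);                 % approximately 1 for each feature
firstSignalMean = mean(XTrainToySD{1},2);       % generally not 0
```

Reusing the training statistics for the validation and testing sets keeps information from those sets out of the preprocessing step.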
Modify LSTM Network Architecture
Now that the signals each have two dimensions, it is necessary to modify the network architecture by specifying the input sequence size as 2. Specify a bidirectional LSTM layer with 50 hidden units, and output the last element of the sequence. Specify two classes by including a fully connected layer of size 2, followed by a softmax layer.
layers = [ ...
    sequenceInputLayer(2)
    bilstmLayer(50,OutputMode="last")
    fullyConnectedLayer(2)
    softmaxLayer]
layers =
  4×1 Layer array with layers:

     1   ''   Sequence Input    Sequence input with 2 dimensions
     2   ''   BiLSTM            BiLSTM with 50 hidden units
     3   ''   Fully Connected   2 fully connected layer
     4   ''   Softmax           softmax
Specify the training options. The training options are identical to those used to train the first network, with the exception of training for 150 epochs instead of 50 and specifying different validation data.
options = trainingOptions("adam", ...
    MaxEpochs=150, ...
    MiniBatchSize=200, ...
    GradientThreshold=1, ...
    Shuffle="every-epoch", ...
    InitialLearnRate=1e-3, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    InputDataFormats="CTB", ...
    ValidationData={XValidSD,YValid}, ...
    OutputNetwork="best-validation", ...
    Verbose=false);
Train LSTM Network with Time-Frequency Features
Train the LSTM network with the specified training options and layer architecture by using trainnet.
net2 = trainnet(XTrainSD,YTrain,layers,"crossentropy",options);
The time required for training decreases because the TF moments are shorter than the raw sequences.
Visualize Training and Testing Accuracy
Classify the training data using the updated LSTM network. Visualize the classification performance as a confusion matrix.
scores = minibatchpredict(net2,XTrainSD,InputDataFormats="CTB");
trainPred2 = scores2label(scores,classNames);
LSTMAccuracy = sum(trainPred2 == YTrain)/numel(YTrain)*100
LSTMAccuracy = 95.5447
figure
confusionchart(YTrain,trainPred2,ColumnSummary="column-normalized",...
    RowSummary="row-normalized",Title="Confusion Chart for LSTM");
Classify the testing data with the updated network. Plot the confusion matrix to examine the testing accuracy.
scores = minibatchpredict(net2,XTestSD,InputDataFormats="CTB");
testPred2 = scores2label(scores,classNames);
LSTMAccuracy = sum(testPred2 == YTest)/numel(YTest)*100
LSTMAccuracy = 94.7896
figure
confusionchart(YTest,testPred2,ColumnSummary="column-normalized",...
    RowSummary="row-normalized",Title="Confusion Chart for LSTM");
Conclusion
This example shows how to build a classifier to detect atrial fibrillation in ECG signals using an LSTM network. The procedure uses oversampling to avoid the classification bias that occurs when one tries to detect abnormal conditions in populations composed mainly of healthy patients. Training the LSTM network using raw signal data results in a poor classification accuracy. Training the network using two time-frequency-moment features for each signal significantly improves the classification performance and also decreases the training time.
References
[1] AF Classification from a Short Single Lead ECG Recording: the PhysioNet/Computing in Cardiology Challenge, 2017. https://physionet.org/challenge/2017/
[2] Clifford, Gari, Chengyu Liu, Benjamin Moody, Li-wei H. Lehman, Ikaro Silva, Qiao Li, Alistair Johnson, and Roger G. Mark. "AF Classification from a Short Single Lead ECG Recording: The PhysioNet Computing in Cardiology Challenge 2017." Computing in Cardiology (Rennes: IEEE). Vol. 44, 2017, pp. 1–4.
[3] Goldberger, A. L., L. A. N. Amaral, L. Glass, J. M. Hausdorff, P. Ch. Ivanov, R. G. Mark, J. E. Mietus, G. B. Moody, C.-K. Peng, and H. E. Stanley. "PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals". Circulation. Vol. 101, No. 23, 13 June 2000, pp. e215–e220. http://circ.ahajournals.org/content/101/23/e215.full
[4] Pons, Jordi, Thomas Lidy, and Xavier Serra. "Experimenting with Musically Motivated Convolutional Neural Networks". 14th International Workshop on Content-Based Multimedia Indexing (CBMI). June 2016.
[5] Wang, D. "Deep learning reinvents the hearing aid," IEEE Spectrum, Vol. 54, No. 3, March 2017, pp. 32–37. doi: 10.1109/MSPEC.2017.7864754.
[6] Brownlee, Jason. How to Scale Data for Long Short-Term Memory Networks in Python. 7 July 2017. https://machinelearningmastery.com/how-to-scale-data-for-long-short-term-memory-networks-in-python/.
Appendix — Helper Functions
helperSegmentSignals: this function takes a cell array of ECG signals, signalsIn, and a categorical array of labels, labelsIn. To avoid excessive padding or truncating, the function ignores signals with fewer than 9000 samples and breaks the remaining signals into as many 9000-sample segments as possible, ignoring any leftover samples. For example, a signal with 18500 samples becomes two 9000-sample signals, and the remaining 500 samples are ignored. If the function segments a signal into several 9000-sample segments, the function labels each segment with the label of the original signal.
function [signalsOut,labelsOut] = helperSegmentSignals(signalsIn,labelsIn)
% This function is only for use in this example. It might change or be
% removed in a future release.

% Specify the length of each segmented signal.
targetLength = 9000;
signalsOut = {};
labelsOut = {};

for idx = 1:numel(signalsIn)
    % Extract signal and label from cell array.
    x = signalsIn{idx};
    y = labelsIn(idx);

    % Skip signals shorter than 9000 samples.
    if length(x) < targetLength
        continue;
    end

    % Partition signal into frames of length 9000.
    x = framesig(x,targetLength);

    % Compute the number of targetLength-sample chunks in the signal.
    numSigs = size(x,2);

    % Repeat the label numSigs times.
    y = repmat(y,[numSigs,1]);

    % Vertically concatenate into cell arrays.
    signalsOut = [signalsOut; mat2cell(x.',ones(numSigs,1))]; %#ok<AGROW>
    labelsOut = [labelsOut; cellstr(y)]; %#ok<AGROW>
end

% Convert labels to categorical.
labelsOut = categorical(labelsOut);
end
helperComputeSpecEntropy: this function computes the spectral entropy.
function [specentropy,t] = helperComputeSpecEntropy(signal,fs)
[s,f,t] = pspectrum(signal,fs,"spectrogram");
specentropy = spectralEntropy(s,f,Scaled=true);
end
See Also
Functions
instfreq (Signal Processing Toolbox) | pentropy (Signal Processing Toolbox) | trainnet | trainingOptions | dlnetwork | bilstmLayer | lstmLayer