
Deep Learning Code Generation on ARM for Fault Detection Using Wavelet Scattering and Recurrent Neural Networks

This example demonstrates code generation for acoustic-based machine fault detection using a wavelet scattering network paired with a recurrent neural network. The example uses MATLAB® Coder™, MATLAB Coder Interface for Deep Learning, and MATLAB Support Package for Raspberry Pi® Hardware to generate a standalone executable (.elf) file that runs on a Raspberry Pi and uses the ARM® Compute Library for the network computations. The input data consists of acoustic time-series recordings from air compressors, and the output is the machine state predicted by the LSTM-based recurrent network. The executable on the Raspberry Pi runs the streaming classifier on each input record received from MATLAB and returns the predicted label to MATLAB on the host. For more details on audio preprocessing and network training, refer to Fault Detection Using Wavelet Scattering and Recurrent Deep Networks.

Generating code for wavelet time scattering can significantly improve its performance. See Generate and Deploy Optimized Code for Wavelet Time Scattering on ARM Targets for more information.

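This example follows these steps:

  • Prepare the input data set.

  • Detect machine faults in MATLAB.

  • Detect machine faults on a Raspberry Pi by using a processor-in-the-loop (PIL) workflow.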

Prerequisites

For a list of supported compilers and libraries, see Generate Code That Uses Third-Party Libraries (MATLAB Coder).

Prepare Input Data Set

Download the data set and unzip the data file in a folder where you have write permission. The recordings are stored as .wav files in folders named for their respective state.

% Download AirCompressorDataset.zip 
component = "audio";
filename = "AirCompressorDataset/AirCompressorDataset.zip";
localfile = matlab.internal.examples.downloadSupportFile(component,filename);

% Unzip the downloaded zip file to the downloadFolder
downloadFolder = fileparts(localfile);
if ~exist(fullfile(downloadFolder, "AirCompressorDataset"),"dir")
    unzip(localfile, downloadFolder)
end

% Create an audioDatastore object, dataStore, to manage the data
dataStore = audioDatastore(downloadFolder,IncludeSubfolders=true,LabelSource="foldernames");

% Use countEachLabel to get the number of samples of each category in the data set
countEachLabel(dataStore)
ans=8×2 table
      Label      Count
    _________    _____

    Bearing       225 
    Flywheel      225 
    Healthy       225 
    LIV           225 
    LOV           225 
    NRV           225 
    Piston        225 
    Riderbelt     225 

To classify the audio recordings, construct a wavelet scattering network that extracts wavelet scattering coefficients to use as features. Each record has 50,000 samples sampled at 16 kHz. Construct the network based on these data characteristics, and set the invariance scale to 0.5 seconds.

Fs = 16e3;
windowLength = 5e4;
IS = 0.5;
sn = waveletScattering(SignalLength=windowLength,SamplingFrequency=Fs, ...
     InvarianceScale=IS);

With these network settings, there are 330 scattering paths and 25 time windows per audio record. This leads to a sixfold reduction in the size of the data for each record.

[~,npaths] = paths(sn);
Ncfs = numCoefficients(sn);
sum(npaths)
ans = 330
Ncfs
Ncfs = 25
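The sixfold figure comes from comparing the number of raw samples in one record with the number of scattering coefficients computed for it (330 paths, each evaluated at 25 time windows):

```latex
\frac{\text{samples per record}}{\text{paths} \times \text{time windows}}
  = \frac{50\,000}{330 \times 25}
  = \frac{50\,000}{8\,250}
  \approx 6.06
```

Dropping the 0-th order path inside faultDetect leaves 329 × 25 = 8,225 coefficients, so the ratio stays close to six.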

Shuffle the datastore, use splitEachLabel to select half of the recordings from each label, and read the selected recordings into the cell array signalToBeTested. You later pass these recordings to the faultDetect function for classification.

rng default;
dataStore = shuffle(dataStore);
[InputFiles,~] = splitEachLabel(dataStore, 0.5);
signalToBeTested = readall(InputFiles);

Detect Machine Faults in MATLAB

The faultDetect function reads the input audio samples, calculates the wavelet scattering features, and performs deep learning classification. For more information, enter type faultDetect at the command line.

type faultDetect
function out = faultDetect(in)
%#codegen

%   Copyright 2022 The MathWorks, Inc.

    persistent net;
    if isempty(net)
        net = coder.loadDeepLearningNetwork("faultDetectNetwork.mat");
    end
    persistent sn;
    if isempty(sn)
        windowLength = 5e4;
        Fs = 16e3;
        IS = 0.5;
        sn = waveletScattering(SignalLength=windowLength,SamplingFrequency=Fs, ...
        InvarianceScale=IS);
    end
    
    S = sn.featureMatrix(in,"transform","log");
    TestFeatures = S(2:330,1:25); %Remove the 0-th order scattering coefficients
    out = net.classify(TestFeatures);

end
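For a single-channel signal, featureMatrix returns one row per scattering path and one column per time window, so S is 330-by-25 with these settings. The indexing S(2:330,1:25) in faultDetect keeps every time window and drops only the first row, which holds the 0-th order coefficients. A minimal sketch of that indexing, using a random matrix as a purely illustrative stand-in for the scattering features:

```matlab
numPaths = 330;      % Total scattering paths, sum(npaths)
numWindows = 25;     % Time windows per record, numCoefficients(sn)

% Random stand-in for S = featureMatrix(sn,x,"transform","log")
S = randn(numPaths,numWindows);

% Drop the 0-th order row and keep all time windows, which is
% equivalent to S(2:330,1:25) for a 330-by-25 matrix
TestFeatures = S(2:end,:);

% Rows are features and columns are time steps of the LSTM input sequence
size(TestFeatures)
```

The LSTM therefore sees a sequence of 25 time steps, each described by 329 features.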

In the following loop, faultDetect extracts the wavelet scattering coefficients from each audio record and passes them to the LSTM-based recurrent network, which returns one of the eight machine health states for that record. For details on the network creation, refer to Fault Detection Using Wavelet Scattering and Recurrent Deep Networks.

numInputs = 10;                   % Validate 10 audio input files
load("faultDetectNetwork.mat");   % Loads the trained network, net

for inputCount = 1:numInputs
    % Get one audio record
    x = signalToBeTested{inputCount};

    % Classify the record and look up the corresponding label
    outputLabel(inputCount) = net.Layers(5).Classes(faultDetect(x));
end
scatter(1:numInputs,outputLabel,140,"filled")
xlabel("Audio Input");
ylabel("Machine Health Status");
title("Machine Health Status per Audio Input on Host")
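The expression net.Layers(5).Classes(...) in the loop looks up a label in the network's list of classes. The sketch below mimics that lookup with a hand-written class list and hypothetical class indices, then tallies how many records fall into each health state. It assumes the classes appear in the same alphabetical order that countEachLabel displays; with the trained network, use net.Layers(5).Classes instead of the hand-written list.

```matlab
% Hand-written class list, assumed to match net.Layers(5).Classes
classes = categorical(["Bearing" "Flywheel" "Healthy" "LIV" ...
                       "LOV" "NRV" "Piston" "Riderbelt"]);

% Hypothetical class indices for five audio records
predictedIdx = [3 1 3 8 5];

% Index into the class list to recover the label of each record
predictedLabels = classes(predictedIdx);

% Count the records in each health state, including states with no records
summaryTable = table(categories(predictedLabels), countcats(predictedLabels(:)), ...
    VariableNames=["State" "Count"]);
disp(summaryTable)
```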

Detect Machine Faults on Raspberry Pi Using PIL Workflow

This section demonstrates code generation and deployment of machine fault detection using wavelet scattering and RNNs on Raspberry Pi hardware. Use a processor-in-the-loop (PIL) workflow for deployment and profiling. For more information, see SIL/PIL Manager Verification Workflow (Embedded Coder).

Create a code generation configuration object to generate the PIL function.

cfg = coder.config("lib","ecoder",true);
cfg.VerificationMode = "PIL";

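Create a deep learning configuration object (dlcfg) for the ARM Compute Library ("arm-compute"). Set the ARM architecture and the ARM Compute Library version, which must match the version installed on your Raspberry Pi, and then attach dlcfg to the code generation configuration object.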

dlcfg = coder.DeepLearningConfig("arm-compute");
dlcfg.ArmArchitecture = "armv7";
dlcfg.ArmComputeVersion = "20.02.1";
cfg.DeepLearningConfig = dlcfg ;

Use the raspi function from the MATLAB Support Package for Raspberry Pi Hardware to create a connection to the Raspberry Pi. In this code, replace these placeholders:

  • raspiname with the host name of your Raspberry Pi

  • username with your user name

  • password with your password

if (~exist("r","var"))
    r = raspi("raspiname", "username", "password");
end

hw = coder.hardware("Raspberry Pi");
cfg.Hardware = hw;

Specify the build directory and set the target language to C++.

buildDir = "~/remoteBuildDir";
cfg.Hardware.BuildDir = buildDir;
cfg.TargetLang = "C++";

Enable profiling and generate the PIL code. A MEX file named faultDetect_pil is generated in your current folder.

cfg.CodeExecutionProfiling = true;
audioFrame = ones(windowLength,1);   % Example input that defines the size and class of the faultDetect argument
codegen -config cfg faultDetect -args {audioFrame} -silent;
 Deploying code. This may take a few minutes. 
### Connectivity configuration for function 'faultDetect': 'Raspberry Pi'

Call the generated PIL function from MATLAB to get the detected outputs and the execution time.

numInputs = 10;                   % Validate 10 audio input files
load("faultDetectNetwork.mat");   % Loads the trained network, net

for inputCount = 1:numInputs
    % Get one audio record
    x = signalToBeTested{inputCount};

    % Run the classifier on the Raspberry Pi and look up the label
    outputLabelPIL(inputCount) = net.Layers(5).Classes(faultDetect_pil(x));
end
### Starting application: 'codegen/lib/faultDetect/pil/faultDetect.elf'
    To terminate execution: clear faultDetect_pil
### Launching application faultDetect.elf...
    Execution profiling data is available for viewing. Open Simulation Data Inspector.
    Execution profiling report available after termination.
scatter(1:numInputs,outputLabelPIL,140,"filled")
xlabel("Audio Input")
ylabel("Machine Health Status")
title("Machine Health Status per Audio Input on Raspberry Pi")
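Because the PIL run executes the same faultDetect logic on the Raspberry Pi, its labels can be checked against the labels from the host run. A minimal sketch of that comparison, using hypothetical categorical arrays in place of the two result arrays:

```matlab
% Hypothetical results standing in for the host and PIL label arrays
hostLabels = categorical(["Healthy" "Bearing" "LIV" "Piston"]);
pilLabels  = categorical(["Healthy" "Bearing" "LIV" "Piston"]);

% Element-wise comparison of the two categorical arrays
isMatch = hostLabels == pilLabels;
fprintf("%d of %d predictions match between host and Raspberry Pi.\n", ...
    nnz(isMatch), numel(isMatch));
```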

Terminate the PIL execution.

clear faultDetect_pil;
### Host application produced the following standard output (stdout) and standard error (stderr) messages:

    Execution profiling report: coder.profile.show(getCoderExecutionProfile('faultDetect'))

Generate an execution profile report to evaluate the execution time. A ScaleFactor of 1e-03 with Units set to "Seconds" reports the times in milliseconds.

executionProfile = getCoderExecutionProfile("faultDetect");
report(executionProfile, ...
       "Units","Seconds", ...
       "ScaleFactor","1e-03", ...
       "NumericFormat","%0.4f");

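One way to interpret the report is to compare the measured time per faultDetect call with the duration of the audio that the call processes. Each record holds 50,000 samples at 16 kHz, or 3.125 seconds of audio, so the classifier keeps pace with a live stream only if one call finishes in less time than that. The sketch below uses a hypothetical execution time, execTime; substitute the time from your own profile report, converted to seconds.

```matlab
Fs = 16e3;                          % Sampling frequency (Hz)
windowLength = 5e4;                 % Samples per audio record
recordDuration = windowLength/Fs;   % Seconds of audio per record

% Hypothetical execution time of one faultDetect call, in seconds.
% Replace it with the value from your execution profile report.
execTime = 0.5;

% Ratio of processing time to audio duration; a value below 1 means
% each record is classified faster than it is recorded
realTimeRatio = execTime/recordDuration;
if realTimeRatio < 1
    fprintf("Processing is %.2f times faster than real time.\n", 1/realTimeRatio);
else
    fprintf("Processing is slower than real time.\n");
end
```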
Summary

In this example, you used the wavelet scattering transform with a simple recurrent network to classify faults in an air compressor, and then generated code that runs the classifier on a Raspberry Pi. The scattering transform extracted robust features for the learning problem. In addition, the data reduction that the scattering transform achieved along the time dimension was critical to making the problem computationally feasible for the recurrent network.


Copyright 2022, The MathWorks, Inc.
