Classify Text Data Using Convolutional Neural Network

This example shows how to classify text data using a convolutional neural network.

To classify text data using convolutions, use 1-D convolutional layers that convolve over the time dimension of the input.

This example trains a network with 1-D convolutional filters of varying widths. The width of each filter corresponds to the number of words the filter can see (the n-gram length). The network has multiple branches of convolutional layers, so it can use different n-gram lengths.
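
To make the link between filter width and n-gram length concrete, the following minimal sketch slides a single filter of width 3 over a toy embedded sequence using only base MATLAB. The embedding values, the filter weights, and all variable names are invented for illustration and are not part of the example. A trained network learns the weights, and its convolution layers also add a bias term, which this sketch omits.

```matlab
% Toy embedded sequence: 4 channels (embedding dimension) by 6 time steps (words).
X = reshape(1:24,4,6);

% One hypothetical filter that spans N = 3 consecutive words across all channels.
N = 3;
W = ones(4,N);

% Without padding, the filter fits at T - N + 1 positions.
T = size(X,2);
numPositions = T - N + 1;

% Each output value summarizes one 3-gram: words t, t+1, ..., t+N-1.
Y = zeros(1,numPositions);
for t = 1:numPositions
    window = X(:,t:t+N-1);
    Y(t) = sum(W .* window,"all");
end

disp(Y)
```

With six words and a filter of width 3, the filter sees four overlapping 3-grams, so Y has four elements.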

Load Data

Read the data in factoryReports.csv into a table and view the first few reports.

data = readtable("factoryReports.csv");
head(data)
                                  Description                                         Category            Urgency            Resolution          Cost 
    _______________________________________________________________________    ______________________    __________    ______________________    _____

    {'Items are occasionally getting stuck in the scanner spools.'        }    {'Mechanical Failure'}    {'Medium'}    {'Readjust Machine'  }       45
    {'Loud rattling and banging sounds are coming from assembler pistons.'}    {'Mechanical Failure'}    {'Medium'}    {'Readjust Machine'  }       35
    {'There are cuts to the power when starting the plant.'               }    {'Electronic Failure'}    {'High'  }    {'Full Replacement'  }    16200
    {'Fried capacitors in the assembler.'                                 }    {'Electronic Failure'}    {'High'  }    {'Replace Components'}      352
    {'Mixer tripped the fuses.'                                           }    {'Electronic Failure'}    {'Low'   }    {'Add to Watch List' }       55
    {'Burst pipe in the constructing agent is spraying coolant.'          }    {'Leak'              }    {'High'  }    {'Replace Components'}      371
    {'A fuse is blown in the mixer.'                                      }    {'Electronic Failure'}    {'Low'   }    {'Replace Components'}      441
    {'Things continue to tumble off of the belt.'                         }    {'Mechanical Failure'}    {'Low'   }    {'Readjust Machine'  }       38

Partition the data into training and validation partitions. Use 80% of the data for training and the remaining data for validation.

cvp = cvpartition(data.Category,Holdout=0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);
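
As a rough picture of what a holdout partition does, the sketch below splits a hypothetical set of 10 observation indices 80/20 using only base MATLAB. The variable names and the number of reports are assumptions for illustration. Unlike cvpartition called with a grouping variable, this sketch does not stratify the split by class.

```matlab
% Fix the random seed so that the sketch gives a repeatable split.
rng(0)

numReports = 10;            % hypothetical number of reports
holdoutFraction = 0.2;

% Shuffle the observation indices and hold out the first 20%.
idxShuffled = randperm(numReports);
numHeldOut = round(holdoutFraction*numReports);
idxHeldOut = idxShuffled(1:numHeldOut);
idxTraining = idxShuffled(numHeldOut+1:end);

fprintf("Training: %d, validation: %d\n",numel(idxTraining),numel(idxHeldOut))
```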

Preprocess Text Data

Extract the text data from the "Description" column of the table and preprocess it using the preprocessText function, listed in the section Preprocess Text Function of the example.

documentsTrain = preprocessText(dataTrain.Description);

Extract the labels from the "Category" column and convert them to categorical.

TTrain = categorical(dataTrain.Category);

View the class names and the number of observations.

classNames = unique(TTrain)
classNames = 4×1 categorical
     Electronic Failure 
     Leak 
     Mechanical Failure 
     Software Failure 

numObservations = numel(TTrain)
numObservations = 384

Extract and preprocess the validation data using the same steps.

documentsValidation = preprocessText(dataValidation.Description);
TValidation = categorical(dataValidation.Category);

Convert Documents to Sequences

To input the documents into a neural network, use a word encoding to convert the documents into sequences of numeric indices.

Create a word encoding from the documents.

enc = wordEncoding(documentsTrain);

View the vocabulary size of the word encoding. The vocabulary size is the number of unique words in the word encoding.

numWords = enc.NumWords
numWords = 438

Convert the documents to sequences of integers using the doc2sequence function.

XTrain = doc2sequence(enc,documentsTrain);

Convert the validation documents to sequences using the word encoding created from the training data.

XValidation = doc2sequence(enc,documentsValidation);
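
The idea behind a word encoding can be sketched with base MATLAB functions alone. The following is only an illustration and is not how wordEncoding or doc2sequence are implemented; the toy tokens and variable names are assumptions. In this sketch, words outside the vocabulary are simply dropped, which is a choice made here for simplicity.

```matlab
% Toy training tokens (already tokenized and lowercased).
tokensTrain = ["mixer" "tripped" "the" "fuses" "the" "mixer"];

% Build a vocabulary of unique words, in order of first appearance.
vocabulary = unique(tokensTrain,"stable");

% Map a new document to indices into the vocabulary.
tokensNew = ["the" "sorter" "tripped" "fuses"];
[isKnown,idx] = ismember(tokensNew,vocabulary);

% Keep only the words that exist in the vocabulary.
sequence = idx(isKnown);

disp(vocabulary)
disp(sequence)
```

Because the vocabulary comes from the training tokens only, the word "sorter" has no index. This is why the validation documents are converted with the encoding created from the training data.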

Define Network Architecture

Define the network architecture for the classification task.

The following steps describe the network architecture.

  • Specify an input size of 1, which corresponds to the channel dimension of the integer sequence input.

  • Embed the input using a word embedding of dimension 100.

  • For the n-gram lengths 2, 3, 4, and 5, create blocks of layers containing a convolutional layer, a batch normalization layer, a ReLU layer, a dropout layer, and a global max pooling layer.

  • For each block, specify 200 convolutional filters of size 1-by-N, where N is the n-gram length.

  • Connect the input layer to each block and concatenate the outputs of the blocks using a concatenation layer.

  • To classify the outputs, include a fully connected layer with output size K, and a softmax layer, where K is the number of classes.
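
The sizes implied by this list can be checked with a little arithmetic. The sketch below assumes the hyperparameter values stated above and a four-class problem. The variables featureSize, convParameterCount, and fcParameterCount are introduced here for illustration only.

```matlab
embeddingDimension = 100;
ngramLengths = [2 3 4 5];
numFilters = 200;
numClasses = 4;   % assumed here; the example derives it from the labels

% Each block reduces a sequence to one value per filter (global max pooling),
% so concatenating the blocks gives numBlocks*numFilters features.
numBlocks = numel(ngramLengths);
featureSize = numBlocks*numFilters;

% Learnable parameters in each 1-D convolution layer:
% N-by-embeddingDimension weights per filter, plus one bias per filter.
convParameterCount = ngramLengths*embeddingDimension*numFilters + numFilters;

% Fully connected layer: one weight per feature per class, plus one bias per class.
fcParameterCount = featureSize*numClasses + numClasses;

disp(featureSize)
disp(convParameterCount)
disp(fcParameterCount)
```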

Specify the network hyperparameters.

embeddingDimension = 100;
ngramLengths = [2 3 4 5];
numFilters = 200;

First, create a dlnetwork object containing the input layer and a word embedding layer of dimension 100. To help connect the word embedding layer to the convolution layers, set the word embedding layer name to "emb". To check that the convolution layers do not convolve the sequences to have a length of zero during training, set the MinLength option to the length of the shortest sequence in the training data.

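net = dlnetwork;
minLength = min(doclength(documentsTrain));
layers = [ 
    sequenceInputLayer(1,MinLength=minLength)
    wordEmbeddingLayer(embeddingDimension,numWords,Name="emb")];
net = addLayers(net,layers);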

For each of the n-gram lengths, create a block of 1-D convolution, batch normalization, ReLU, dropout, and 1-D global max pooling layers. Connect each block to the word embedding layer.

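numBlocks = numel(ngramLengths);
for j = 1:numBlocks
    N = ngramLengths(j);
    % The padding mode, the dropout probability, and the "bn", "relu", and
    % "drop" layer names are not given in the text and are assumptions.
    block = [
        convolution1dLayer(N,numFilters,Name="conv"+N,Padding="same")
        batchNormalizationLayer(Name="bn"+N)
        reluLayer(Name="relu"+N)
        dropoutLayer(0.2,Name="drop"+N)
        globalMaxPooling1dLayer(Name="max"+N)];
    net = addLayers(net,block);
    net = connectLayers(net,"emb","conv"+N);
end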

Add the concatenation layer, the fully connected layer, and the softmax layer.

numClasses = numel(classNames);

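% Concatenating along dimension 1 and the "fc" and "soft" layer names are
% assumptions; the text specifies only the "cat" name and the layer types.
layers = [
    concatenationLayer(1,numBlocks,Name="cat")
    fullyConnectedLayer(numClasses,Name="fc")
    softmaxLayer(Name="soft")];

net = addLayers(net,layers);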

Connect the global max pooling layers to the concatenation layer and view the network architecture in a plot.

for j = 1:numBlocks
    N = ngramLengths(j);
    net = connectLayers(net,"max"+N,"cat/in"+j);
end

figure
plot(net)
title("Network Architecture")
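
To see why the blocks can be concatenated even though documents differ in length, the following sketch applies global max pooling to hypothetical activations using base MATLAB. The activation values and variable names are invented for illustration.

```matlab
% Hypothetical activations of one block: 3 filters by 5 time steps.
Z = [0.1 0.7 0.2 0.0 0.3
     0.5 0.4 0.9 0.1 0.2
     0.0 0.0 0.3 0.8 0.6];

% Global max pooling keeps the largest response of each filter over time.
pooled = max(Z,[],2);

% A shorter sequence gives an output of the same size.
pooledShort = max(Z(:,1:2),[],2);

% Concatenate the pooled outputs of two blocks along the channel dimension.
features = cat(1,pooled,pooledShort);

disp(pooled')
disp(size(features))
```

Each pooled output has one element per filter regardless of the sequence length, so the concatenated feature vector always has the same size.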

Train Network

Specify the training options:

  • Train with a mini-batch size of 128. By default, the trainnet (Deep Learning Toolbox) function uses a GPU if one is available. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainnet function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.

  • Validate the network using the validation data.

  • Return the network with the lowest validation loss.

  • Monitor the training progress in a plot and monitor the accuracy metric.

  • Suppress the verbose output.

  • Because the training data has sequences with rows and columns corresponding to channels and time steps, respectively, specify the input data format "CTB" (channel, time, batch).

options = trainingOptions("adam", ...
    MiniBatchSize=128, ...
    ValidationData={XValidation,TValidation}, ...
    OutputNetwork="best-validation", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false, ...
    InputDataFormats="CTB");

Train the network using the trainnet function.

net = trainnet(XTrain,TTrain,net,"crossentropy",options);
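
The "CTB" layout can be illustrated with plain arrays. The sketch below uses two hypothetical index sequences and stacks them into a channel-by-time-by-batch array. The variable names and the zero padding on the right are choices made for this sketch; how the training function batches and pads sequences internally is not shown here.

```matlab
% Hypothetical sequences of word indices, one cell per document.
sequences = {[3 2 4], [1 5 2 6 4]};

% Each sequence is 1-by-T: one channel (row) and T time steps (columns).
[numChannels,numTimeSteps] = cellfun(@size,sequences);

% Stack the sequences into a channel-by-time-by-batch array, padding the
% shorter sequence with zeros on the right (a choice made for this sketch).
maxLength = max(numTimeSteps);
batch = zeros(1,maxLength,numel(sequences));
for i = 1:numel(sequences)
    batch(1,1:numTimeSteps(i),i) = sequences{i};
end

disp(size(batch))
```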

Test Network

Make predictions using the neural network. To make predictions with multiple observations, use the minibatchpredict function. To convert the prediction scores to labels, use the scores2label function. The minibatchpredict function automatically uses a GPU if one is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the function uses the CPU.

Because the data has sequences with rows and columns corresponding to channels and time steps, respectively, specify the input data format "CTB" (channel, time, batch).

scores = minibatchpredict(net,XValidation,InputDataFormats="CTB");
YValidation = scores2label(scores,classNames);
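
Conceptually, converting scores to labels means picking the class with the highest score for each observation. The following base MATLAB sketch shows that idea on invented scores, assuming one row per observation and one column per class. It is an illustration only and not the implementation of scores2label.

```matlab
% Hypothetical class names and scores: one row per report, one column per class.
toyClassNames = categorical(["Electronic Failure";"Leak";"Mechanical Failure"]);
toyScores = [0.1 0.7 0.2
             0.6 0.3 0.1];

% For each row, find the column with the highest score.
[~,idx] = max(toyScores,[],2);

% Look up the class name for each winning column.
toyLabels = toyClassNames(idx);

disp(toyLabels)
```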

Visualize the predictions in a confusion chart.

figure
confusionchart(TValidation,YValidation)

Calculate the classification accuracy. The accuracy is the proportion of labels predicted correctly.

accuracy = mean(TValidation == YValidation)
accuracy = 0.9688
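
The accuracy calculation can be traced by hand on a small case. The sketch below uses four hypothetical true and predicted labels; the label values and variable names are invented for illustration.

```matlab
% Hypothetical true and predicted labels for four reports.
trueLabels = categorical(["Leak";"Leak";"Electronic Failure";"Mechanical Failure"]);
predictedLabels = categorical(["Leak";"Electronic Failure";"Electronic Failure";"Mechanical Failure"]);

% Element-wise comparison gives a logical vector; its mean is the accuracy.
isCorrect = trueLabels == predictedLabels;
toyAccuracy = mean(isCorrect);

disp(toyAccuracy)
```

Three of the four toy predictions match, so the toy accuracy is 0.75. By the same arithmetic, an accuracy of 0.9688 would correspond to 93 correct predictions out of 96, if the validation set holds 96 reports. That count is not stated in the example and is inferred from 384 training reports being 80% of the data.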

Predict Using New Data

Classify the event type of three new reports. Create a string array containing the new reports.

reportsNew = [ 
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

Preprocess the text data using the same preprocessing steps as the training and validation documents.

documentsNew = preprocessText(reportsNew);
XNew = doc2sequence(enc,documentsNew);

Classify the new sequences using the trained network.

scores = minibatchpredict(net,XNew,InputDataFormats="CTB");
YNew = scores2label(scores,classNames)
YNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

Preprocess Text Function

The preprocessText function takes text data as input and performs these steps:

  1. Tokenize the text.

  2. Convert the text to lowercase.

function documents = preprocessText(textData)

documents = tokenizedDocument(textData);
documents = lower(documents);

end
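
For intuition about the two steps, the sketch below approximates them on a single hypothetical report using base MATLAB string functions. It is a simplification: splitting at whitespace leaves punctuation attached to the neighboring word, whereas a tokenizer such as tokenizedDocument applies more elaborate rules. The variable names are invented for illustration.

```matlab
% Hypothetical report text.
reportText = "Mixer tripped the fuses.";

% Lowercase the text, then split it at whitespace.
words = split(lower(reportText))';

disp(words)
```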

