Train a Deep Learning Vehicle Detector

This example shows how to train a vision-based vehicle detector using deep learning.

Overview

Vehicle detection using computer vision is an important component for tracking vehicles around the ego vehicle. The ability to detect and track vehicles is required for many autonomous driving applications, such as forward collision warning, adaptive cruise control, and automated lane keeping. Automated Driving Toolbox™ provides pretrained vehicle detectors (vehicleDetectorFasterRCNN and vehicleDetectorACF) to enable quick prototyping. However, the pretrained models might not suit every application, in which case you can train a detector from scratch. This example shows how to train a vehicle detector from scratch using deep learning.
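
For quick prototyping with the pretrained detectors, a minimal sketch looks like the following. It assumes Automated Driving Toolbox is installed; substitute your own image file for the one shown here.

% Load the pretrained ACF vehicle detector shipped with Automated Driving Toolbox.
pretrainedDetector = vehicleDetectorACF();

% Run the detector on a test image (replace with your own image file).
I = imread('highway.png');
[bboxes, scores] = detect(pretrainedDetector, I);

% Annotate and display the detections.
I = insertObjectAnnotation(I, 'rectangle', bboxes, scores);
figure
imshow(I)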

Deep learning is a powerful machine learning technique that automatically learns image features required for detection tasks. There are several techniques for object detection using deep learning such as Faster R-CNN and you only look once (YOLO) v2. This example uses the Faster R-CNN [1] technique, which is implemented in the trainFasterRCNNObjectDetector function.

To learn more, see Object Detection using Deep Learning (Computer Vision Toolbox).

Note: This example requires Deep Learning Toolbox™. Parallel Computing Toolbox™ is recommended so that you can train the detector on a CUDA-capable NVIDIA™ GPU with compute capability 3.0 or higher.
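
As an optional check (a sketch, assuming Parallel Computing Toolbox is installed), you can confirm that MATLAB sees a supported GPU before starting training:

% Check whether a supported GPU is available for training.
if gpuDeviceCount > 0
    gpuDevice   % display the selected GPU, including its compute capability
else
    disp('No supported GPU detected. Training will run on the CPU.');
end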

Download Pretrained Detector

This example uses a pretrained detector to allow the example to run without having to wait for training to complete. If you want to train the detector with the trainFasterRCNNObjectDetector function, set the doTrainingAndEval variable to true. Otherwise, download the pretrained detector.

doTrainingAndEval = false;
if ~doTrainingAndEval && ~exist('fasterRCNNResNet50VehicleExample.mat','file')
    % Download pretrained detector.
    disp('Downloading pretrained detector (118 MB)...');
    pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/fasterRCNNResNet50VehicleExample.mat';
    websave('fasterRCNNResNet50VehicleExample.mat',pretrainedURL);
end

Load Dataset

This example uses a small vehicle data set that contains 295 images. Each image contains one or two labeled instances of a vehicle. A small data set is useful for exploring the Faster R-CNN training procedure, but in practice, more labeled images are needed to train a robust detector.
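
If you label additional images yourself with the Image Labeler app, one way to produce the same table format is objectDetectorTrainingData. This is a sketch that assumes a groundTruth object named gTruth has been exported from the app:

% Convert an exported groundTruth object into a training data table with an
% imageFilename column and one column of boxes per label (assumes gTruth exists).
labeledData = objectDetectorTrainingData(gTruth);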

% Unzip vehicle dataset images.
unzip vehicleDatasetImages.zip

% Load vehicle dataset ground truth.
data = load('vehicleDatasetGroundTruth.mat');
vehicleDataset = data.vehicleDataset;

The ground truth data is stored in a table. The first column contains the path to the image files. The remaining columns contain the ROI labels for vehicles.

% Display first few rows of the data set.
vehicleDataset(1:4,:)
ans =

  4×2 table

             imageFilename               vehicle
    _______________________________    ____________

    'vehicleImages/image_00001.jpg'    [1×4 double]
    'vehicleImages/image_00002.jpg'    [1×4 double]
    'vehicleImages/image_00003.jpg'    [1×4 double]
    'vehicleImages/image_00004.jpg'    [1×4 double]
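
Each entry in the vehicle column is an M-by-4 matrix of bounding boxes in [x y width height] format, with one row per labeled vehicle in that image. For example:

% Inspect the boxes for the first image.
vehicleDataset.vehicle{1}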

Display one of the images from the data set to understand the type of images it contains.

% Add the full path to the local vehicle data folder.
vehicleDataset.imageFilename = fullfile(pwd, vehicleDataset.imageFilename);

% Read one of the images.
I = imread(vehicleDataset.imageFilename{10});

% Insert the ROI labels.
I = insertShape(I, 'Rectangle', vehicleDataset.vehicle{10});

% Resize and display image.
I = imresize(I,3);
figure
imshow(I)

Split the data set into a training set for training the detector, and a test set for evaluating the detector. Select 60% of the data for training. Use the rest for evaluation.

% Set random seed to ensure example training reproducibility.
rng(0);

% Randomly split data into a training and test set.
shuffledIdx = randperm(height(vehicleDataset));
idx = floor(0.6 * height(vehicleDataset));
trainingData = vehicleDataset(shuffledIdx(1:idx),:);
testData = vehicleDataset(shuffledIdx(idx+1:end),:);

Configure Training Options

trainFasterRCNNObjectDetector trains the detector in four steps. The first two steps train the region proposal and detection networks used in Faster R-CNN. The final two steps combine the networks from the first two steps such that a single network is created for detection [1]. Specify the network training options for all steps using trainingOptions.

% Set the training options used for all four training steps.
options = trainingOptions('sgdm', ...
    'MaxEpochs', 5, ...
    'MiniBatchSize', 1, ...
    'InitialLearnRate', 1e-3, ...
    'CheckpointPath', tempdir);

The 'MiniBatchSize' property is set to 1 because the vehicle data set contains images of different sizes, which prevents them from being batched together for processing. If the training images are all the same size, choose a MiniBatchSize greater than 1 to reduce training time.
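
If your own images vary in size, one option is to resize them to a common size and scale the boxes to match before training. The following is a sketch under assumptions: the resizedImages folder and the target size are hypothetical choices for illustration.

% Resize every training image to a common size and scale its boxes accordingly,
% writing the resized copies to a new folder so a larger MiniBatchSize can be used.
targetSize = [300 500];   % [rows cols], chosen arbitrarily for this sketch
mkdir('resizedImages');
for k = 1:height(trainingData)
    I = imread(trainingData.imageFilename{k});
    [rows, cols, ~] = size(I);
    scale = [targetSize(2)/cols targetSize(1)/rows];   % [xScale yScale]
    I = imresize(I, targetSize);

    % Boxes are [x y width height]: scale x and width by xScale, y and height by yScale.
    trainingData.vehicle{k} = round(double(trainingData.vehicle{k}) .* [scale scale]);

    % Save the resized image and point the table at the new file.
    [~, name, ext] = fileparts(trainingData.imageFilename{k});
    newFile = fullfile(pwd, 'resizedImages', [name ext]);
    imwrite(I, newFile);
    trainingData.imageFilename{k} = newFile;
end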

The 'CheckpointPath' property is set to a temporary location for all the training options. This name-value pair enables the saving of partially trained detectors during the training process. If training is interrupted, such as from a power outage or system failure, you can resume training from the saved checkpoint.
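
If training is interrupted, a sketch of resuming from a saved checkpoint looks like the following. The checkpoint file name is hypothetical, and the variable name inside the MAT-file (detector) is an assumption; check the file saved in your CheckpointPath.

% Load a detector checkpoint saved by trainFasterRCNNObjectDetector (the file
% name below is hypothetical) and pass it back in to resume training.
checkpoint = load(fullfile(tempdir, 'faster_rcnn_checkpoint__100.mat'));
detector = trainFasterRCNNObjectDetector(trainingData, checkpoint.detector, options);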

Train Faster R-CNN

The Faster R-CNN object detection network is composed of a feature extraction network followed by two sub-networks. The feature extraction network is typically a pretrained CNN such as ResNet-50 or Inception v3. For more details, see Pretrained Deep Neural Networks (Deep Learning Toolbox). The first sub-network following the feature extraction network is a region proposal network (RPN) trained to generate object proposals (object or background). The second sub-network is trained to predict the actual class of each proposal (in this example, vehicle).

This example uses a pretrained ResNet-50 for feature extraction. Other pretrained networks such as MobileNet v2 or ResNet-18 can also be used depending on application requirements. The trainFasterRCNNObjectDetector function automatically adds the sub-networks required for object detection. You can also create a custom Faster R-CNN network. See Create Faster R-CNN Object Detection Network (Computer Vision Toolbox).
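
As one illustration of the custom route, the following sketch builds a Faster R-CNN layer graph on top of ResNet-50. It assumes the fasterRCNNLayers function is available in your Computer Vision Toolbox release and uses hypothetical anchor box sizes.

% Build a Faster R-CNN network for one object class on top of ResNet-50.
inputImageSize = [224 224 3];
numClasses = 1;                              % only the 'vehicle' class
anchorBoxes = [64 64; 128 128; 192 192];     % hypothetical [height width] anchor sizes
lgraph = fasterRCNNLayers(inputImageSize, numClasses, anchorBoxes, 'resnet50');

% The layer graph can then be passed to trainFasterRCNNObjectDetector in place
% of the 'resnet50' network name.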

Train Faster R-CNN object detector if doTrainingAndEval is true. Otherwise, you can load a pretrained network.

if doTrainingAndEval

    % Train Faster R-CNN detector.
    %  * Use 'resnet50' as the feature extraction network.
    %  * Adjust the NegativeOverlapRange and PositiveOverlapRange to ensure
    %    training samples tightly overlap with ground truth.
    [detector, info] = trainFasterRCNNObjectDetector(trainingData, 'resnet50', options, ...
        'NegativeOverlapRange', [0 0.3], ...
        'PositiveOverlapRange', [0.6 1]);
else
    % Load pretrained detector for the example.
    pretrained = load('fasterRCNNResNet50VehicleExample.mat');
    detector = pretrained.detector;
end

% Note: This example was verified on an NVIDIA(TM) Titan X with 12 GB of GPU
% memory. Training this network took approximately 10 minutes using this setup.
% Training time varies depending on the hardware you use.

To quickly verify the training, run the detector on a test image.

% Read a test image.
I = imread(testData.imageFilename{1});

% Run the detector.
[bboxes, scores] = detect(detector, I);

% Annotate detections in the image.
I = insertObjectAnnotation(I, 'rectangle', bboxes, scores);
figure
imshow(I)

Evaluate Detector Using Test Set

Testing a single image showed promising results. To fully evaluate the detector, testing it on a larger set of images is recommended. Computer Vision Toolbox™ provides object detector evaluation functions to measure common metrics such as average precision (evaluateDetectionPrecision) and log-average miss rates (evaluateDetectionMissRate). Here, the average precision metric is used. The average precision provides a single number that incorporates the ability of the detector to make correct classifications (precision) and the ability of the detector to find all relevant objects (recall).

The first step for detector evaluation is to collect the detection results by running the detector on the test set. To avoid a long evaluation time, the results are loaded from disk by default. Set the doTrainingAndEval flag to true to run the evaluation locally.

if doTrainingAndEval
    % Create a table to hold the bounding boxes, scores, and labels output by
    % the detector.
    numImages = height(testData);
    results = table('Size',[numImages 3],...
        'VariableTypes',{'cell','cell','cell'},...
        'VariableNames',{'Boxes','Scores','Labels'});

    % Run detector on each image in the test set and collect results.
    for i = 1:numImages

        % Read the image.
        I = imread(testData.imageFilename{i});

        % Run the detector.
        [bboxes, scores, labels] = detect(detector, I);

        % Collect the results.
        results.Boxes{i} = bboxes;
        results.Scores{i} = scores;
        results.Labels{i} = labels;
    end
else
    % Load pretrained detector for the example.
    pretrained = load('fasterRCNNResNet50VehicleExample.mat');
    results = pretrained.results;
end

% Extract expected bounding box locations from test data.
expectedResults = testData(:, 2:end);

% Evaluate the object detector using average precision metric.
[ap, recall, precision] = evaluateDetectionPrecision(results, expectedResults);

The precision/recall (PR) curve highlights how precise a detector is at varying levels of recall. Ideally, the precision would be 1 at all recall levels. The use of additional layers in the network can help improve the average precision, but might require additional training data and longer training time.

% Plot precision/recall curve
figure
plot(recall, precision)
xlabel('Recall')
ylabel('Precision')
grid on
title(sprintf('Average Precision = %.2f', ap))
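
The log-average miss rate mentioned above can be computed from the same detection results. This is a sketch using evaluateDetectionMissRate:

% Evaluate the detector using the log-average miss rate metric.
[logAverageMissRate, fppi, missRate] = evaluateDetectionMissRate(results, expectedResults);

% Plot the miss rate as a function of false positives per image (FPPI).
figure
loglog(fppi, missRate)
xlabel('False Positives Per Image')
ylabel('Miss Rate')
grid on
title(sprintf('Log Average Miss Rate = %.2f', logAverageMissRate))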

Summary

This example showed how to train a vehicle detector using deep learning. You can follow similar steps to train detectors for traffic signs, pedestrians, or other objects.

To learn more about deep learning, see Object Detection using Deep Learning (Computer Vision Toolbox).

References

[1] Ren, Shaoqing, et al. "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks." Advances in Neural Information Processing Systems. 2015.
