

Cross-validate function for classification



    vals = kfoldfun(CVMdl,fun) cross-validates the function fun by applying fun to the data stored in the cross-validated model CVMdl. You must pass fun as a function handle.
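As a minimal sketch of the calling pattern, the block below passes an anonymous function handle that returns one scalar per fold: the misclassification rate on that fold's test data. It assumes the Statistics and Machine Learning Toolbox is installed. The name perFoldError is hypothetical, and the tildes ignore the inputs this function does not use.

```matlab
% Minimal sketch: per-fold misclassification rate computed with kfoldfun.
% Assumes the Statistics and Machine Learning Toolbox is installed.
load fisheriris
rng(1);                                    % For reproducibility
CVMdl = crossval(fitctree(meas,species));  % Default 10-fold cross-validation

% perFoldError is a hypothetical name. The tildes ignore the training
% data and both weight vectors, which this function does not need.
perFoldError = @(CMP,~,~,~,Xtest,Ytest,~) ...
    mean(~strcmp(predict(CMP,Xtest),Ytest));

vals = kfoldfun(CVMdl,perFoldError)        % 10-by-1, one row per fold
```

Because each call returns a scalar, the vertical concatenation over the 10 folds yields a column vector.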



    Train a classification tree classifier, and then cross-validate it using a custom k-fold loss function.

    Load Fisher’s iris data set.

    load fisheriris

    Train a classification tree classifier.

    Mdl = fitctree(meas,species);

    Mdl is a ClassificationTree model.

    Cross-validate Mdl using the default 10-fold cross-validation. Compute the classification error (proportion of misclassified observations) for the validation-fold observations.

    rng(1); % For reproducibility
    CVMdl = crossval(Mdl);
    L = kfoldLoss(CVMdl,'LossFun','classiferror')
    L = 0.0467

    Examine the result when misclassifying a flower as versicolor costs 10 and any other misclassification costs 1. The custom function noversicolor, shown at the end of this example, implements this cost scheme and returns the average cost over the test observations of a fold.
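To make the cost arithmetic concrete, the following sketch applies the same rule as noversicolor to four hypothetical labels (made up for illustration, not taken from the iris data). It needs no toolbox because it only compares cell arrays of character vectors.

```matlab
% Worked example of the noversicolor cost rule on hypothetical labels
% (four made-up observations, not taken from the iris data).
Ytest    = {'setosa';'versicolor';'virginica';'virginica'};
Ypredict = {'setosa';'virginica';'versicolor';'virginica'};

misclassified = ~strcmp(Ypredict,Ytest);          % [0;1;1;0]
asVersicolor  = strcmp(Ypredict,'versicolor');    % [0;0;1;0]

% Each error costs 1; an error predicted as versicolor costs 9 more (10 total)
cost = sum(misclassified) + 9*sum(misclassified & asVersicolor); % 2 + 9 = 11
averageCost = cost/numel(Ytest)                   % 11/4 = 2.75
```

Observation 2 (a versicolor predicted as virginica) costs 1, while observation 3 (a virginica predicted as versicolor) costs 10, so the average cost is 2.75 even though the plain error rate is 0.5.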

    Compute the mean of the per-fold noversicolor costs.

    mean(kfoldfun(CVMdl,@noversicolor))
    ans = 0.2267

    This code creates the function noversicolor.

    function averageCost = noversicolor(CMP,~,~,~,Xtest,Ytest,~)
    % noversicolor Example custom cross-validation function
    %    Attributes a cost of 10 for misclassifying an iris as versicolor,
    %    and 1 for any other misclassification. This example function
    %    requires the fisheriris data set.
    Ypredict = predict(CMP,Xtest);
    misclassified = not(strcmp(Ypredict,Ytest)); % Misclassified observations
    classifiedAsVersicolor = strcmp(Ypredict,'versicolor'); % Predicted as versicolor
    % Each misclassification costs 1; one predicted as versicolor costs 9 more
    cost = sum(misclassified) + ...
        9*sum(misclassified & classifiedAsVersicolor); % Total cost
    averageCost = cost/numel(Ytest); % Average cost per test observation
    end

    Input Arguments


    CVMdl: cross-validated model, specified as a ClassificationPartitionedModel object, ClassificationPartitionedEnsemble object, or ClassificationPartitionedGAM object.
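The sketch below shows what kfoldfun iterates over: the cross-validated model stores one compact model per fold in its Trained property, and KFold gives the number of folds. It assumes the Statistics and Machine Learning Toolbox is installed and uses the default 10-fold partition.

```matlab
% Sketch: inspect the cross-validated model that kfoldfun iterates over.
% Assumes the Statistics and Machine Learning Toolbox is installed.
load fisheriris
rng(1);                          % For reproducibility
Mdl = fitctree(meas,species);
CVMdl = crossval(Mdl);           % Default 10-fold cross-validation

nFolds = CVMdl.KFold             % Number of folds
nModels = numel(CVMdl.Trained)   % Cell array with one compact model per fold
```

kfoldfun passes each element of CVMdl.Trained to fun as CMP, together with that fold's training and test data.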

    fun: cross-validated function, specified as a function handle. fun must have this syntax:

    testvals = fun(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest)

    • CMP is a compact model stored in one element of the CVMdl.Trained property.

    • Xtrain is the training matrix of predictor values.

    • Ytrain is the training array of response values.

    • Wtrain is the vector of training observation weights.

    • Xtest and Ytest are the test data, with associated weights Wtest.

    • The returned value testvals must have the same size across all folds.
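As a sketch of a function that uses the weight inputs described above, the block below computes a weighted misclassification rate for each fold from Wtest. It assumes the Statistics and Machine Learning Toolbox is installed; the name weightedError is hypothetical. Dividing by sum(Wtest) keeps the result independent of how the stored weights are scaled.

```matlab
% Sketch: weighted per-fold error that uses the Wtest input.
% Assumes the Statistics and Machine Learning Toolbox is installed.
load fisheriris
rng(1);                                    % For reproducibility
CVMdl = crossval(fitctree(meas,species));  % Default 10-fold cross-validation

% weightedError is a hypothetical name. Each misclassified test
% observation contributes its weight; the sum is normalized by sum(Wtest).
weightedError = @(CMP,~,~,~,Xtest,Ytest,Wtest) ...
    sum(Wtest(:).*~strcmp(predict(CMP,Xtest),Ytest)) / sum(Wtest);

vals = kfoldfun(CVMdl,weightedError);      % 10-by-1, one value per fold
```

Because every call returns a scalar, testvals has the same size in every fold, as required.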

    Data Types: function_handle

    Output Arguments


    vals: cross-validation results, returned as a numeric matrix. vals contains the testvals outputs from all folds, concatenated vertically. For example, if testvals from every fold is a numeric row vector of length N, kfoldfun returns a KFold-by-N numeric matrix with one row per fold.
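The following sketch illustrates the KFold-by-N shape with N = 2: the hypothetical function foldCounts returns a 1-by-2 row per fold holding the number of test observations and the number of misclassified ones. It assumes the Statistics and Machine Learning Toolbox is installed and the default 10-fold partition.

```matlab
% Sketch: a 1-by-2 testvals per fold yields a KFold-by-2 matrix.
% Assumes the Statistics and Machine Learning Toolbox is installed.
load fisheriris
rng(1);                                    % For reproducibility
CVMdl = crossval(fitctree(meas,species));  % Default 10-fold cross-validation

% foldCounts is a hypothetical name. Row per fold:
% [number of test observations, number misclassified]
foldCounts = @(CMP,~,~,~,Xtest,Ytest,~) ...
    [numel(Ytest), sum(~strcmp(predict(CMP,Xtest),Ytest))];

vals = kfoldfun(CVMdl,foldCounts);         % 10-by-2, one row per fold
overallError = sum(vals(:,2))/sum(vals(:,1)) % Unweighted out-of-fold error rate
```

Each observation falls in exactly one test fold, so the first column sums to the 150 observations in the data set. overallError ignores observation weights, so it is a hand-rolled check rather than a replacement for kfoldLoss.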

    Data Types: double

    Introduced in R2011a