How do I determine query points?
Hi,
I've recently been introduced to Shapley values and I'm trying to follow the MATLAB tutorial on them, but I'm unsure about the section on query points.
In my case, I am trying to determine the importance of the different predictors in a classification tree trained on a table of 19 possible predictors. The tree outputs one of four categories, A, B, C or D, with A being the preferred category.
In the MATLAB tutorial, the code for choosing the query point is:
queryPoint = tbl(end,:)
This takes the last row of the table tbl. This may be a stupid question, but how do I adapt this to my work? Do I replace it with something along the lines of:
queryPoint = tbl(x,:)
where x is the row number of a data point that is categorised as A?
Apologies if this sounds dumb; I'm very new to machine learning techniques, but thank you for any help!
Sahas
on 4 Sep 2024
As I understand it, you would like to analyze a specific data point associated with a specific category, and you are using the "queryPoint" variable to understand the contribution of each feature to the model's prediction for that point.
Assuming you are using MATLAB R2023b, specify the query point by using the "QueryPoint" name-value argument rather than "QueryPoints". This is because before MATLAB R2024a, you can specify only one query point when calculating Shapley values; "QueryPoints", which accepts multiple query points, was introduced in R2024a.
Refer to the following MathWorks documentation of MATLAB R2023b version to know how to compute “Shapley values” for a query point: https://www.mathworks.com/help/releases/R2023b/stats/shapley.html
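Putting the pieces together, the single-query-point workflow might look like the sketch below. It is hedged in two ways: the table is synthetic (three hypothetical predictors P1 to P3 and a made-up response column named Category stand in for your 19-predictor table), and it uses the R2023b-style 'QueryPoint' argument. Selecting only Mdl.PredictorNames for the query point is a cautious choice to keep the response column out of it, not something the tutorial requires.

```matlab
% Sketch only: synthetic data stands in for your 19-predictor table.
rng(0)                                   % reproducible random data
n = 200;
tbl = array2table(randn(n,3), 'VariableNames', {'P1','P2','P3'});

% Hypothetical response: four categories A-D derived from P1
edges = [-Inf -0.7 0 0.7 Inf];
labels = ["A"; "B"; "C"; "D"];
tbl.Category = categorical(labels(discretize(tbl.P1, edges)));

% Train a classification tree with "Category" as the response
Mdl = fitctree(tbl, 'Category');

% Pick the first row labelled "A" and keep only the predictor columns
x = find(tbl.Category == "A", 1);
queryPoint = tbl(x, Mdl.PredictorNames);

% R2023b syntax: a single query point via the 'QueryPoint' argument
explainer = shapley(Mdl, 'QueryPoint', queryPoint);

% Bar graph of each predictor's Shapley value for this query point
plot(explainer)
```

With your own data, replace the synthetic table with yours and 'Category' with the name of your response column.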
In the example cited above, you can replace the "end" keyword with the index of a particular row to analyze a specific data point in category "A", as in the following code snippet:
% Assume "Category" is the name of the response column in tbl
% Find the index of the first row in category "A"
% (string() lets the comparison work for categorical, cell, or string columns)
x = find(string(tbl.Category) == "A", 1);
% Select that row to use as the query point
queryPoint = tbl(x,:);
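A row whose true label is "A" is not guaranteed to be one the tree actually classifies as "A". If you want the explanation to describe an "A" prediction, one option is to also check the model's output with predict, as in this sketch. As before, the table, predictor names P1 to P3, and the Category column are hypothetical stand-ins.

```matlab
% Sketch only: hypothetical synthetic data and model
rng(0)
n = 200;
tbl = array2table(randn(n,3), 'VariableNames', {'P1','P2','P3'});
edges = [-Inf -0.7 0 0.7 Inf];
labels = ["A"; "B"; "C"; "D"];
tbl.Category = categorical(labels(discretize(tbl.P1, edges)));
Mdl = fitctree(tbl, 'Category');

% predict returns labels of the same type as the training response (categorical)
predictedLabels = predict(Mdl, tbl(:, Mdl.PredictorNames));

% First row that is labelled "A" AND that the tree also classifies as "A"
x = find(tbl.Category == "A" & predictedLabels == "A", 1);
queryPoint = tbl(x, Mdl.PredictorNames);
```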
Follow the steps given in the documentation to compute Shapley values for a single query point or for multiple query points, depending on your requirements and the MATLAB version you are using.
Refer to the following MathWorks documentation of MATLAB R2024a version to know how to compute “Shapley values” for multiple query points: https://www.mathworks.com/help/stats/shapley.html
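In R2024a or later, you could instead explain every row in category "A" at once. The sketch below assumes the 'QueryPoints' name-value argument described in the R2024a documentation linked above and reuses the same hypothetical synthetic table. See that page for how plot summarizes results across several query points.

```matlab
% Sketch only: hypothetical synthetic data and model
rng(0)
n = 200;
tbl = array2table(randn(n,3), 'VariableNames', {'P1','P2','P3'});
edges = [-Inf -0.7 0 0.7 Inf];
labels = ["A"; "B"; "C"; "D"];
tbl.Category = categorical(labels(discretize(tbl.P1, edges)));
Mdl = fitctree(tbl, 'Category');

% All rows labelled "A", predictor columns only
queryPoints = tbl(tbl.Category == "A", Mdl.PredictorNames);

% R2024a and later: pass several query points through 'QueryPoints'
explainer = shapley(Mdl, 'QueryPoints', queryPoints);
plot(explainer)
```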
Hope this is beneficial!