Better Interpretability Leads to Better Adoption
Is your highly-trained model easy to understand? A sophisticated machine learning algorithm can usually produce accurate predictions, but its notorious “black box” nature does not help adoption at all. Think about this: if you ask me to swallow a black pill without telling me what’s in it, I certainly don’t want to swallow it. The interpretability of a model is like the label on a drug bottle. We need to make our effective pill transparent for easy adoption.
How can we do that? The SHAP value is a great tool, alongside others such as LIME and ELI5. In this article I will present what the Shapley value is and how the SHAP (SHapley Additive exPlanations) value emerges from the Shapley concept. I will demonstrate how SHAP values increase model transparency. This article also comes with Python code so you can produce similar results in your own applications.
What is the Shapley Value?
Let me explain the Shapley value with a story: assume Ann, Bob and Cindy together hammered a 38-inch “error” wood log into the ground. Once the log was driven down, they went to a local bar for a drink. I, a mathematician, came to join them. I asked a very bizarre question: “What is everyone’s contribution (in inches)?”
How do we answer this question? I listed all the permutations and came up with the data in Table A. (Don’t ask me how I got it! lol.) When the ordering is A, B, C, the marginal contributions of the three are 4, 30, and 4 inches respectively.
The table shows that the coalition (A,B), or equivalently (B,A), drives the log 34 inches, so the marginal contribution of C to this coalition is 4 inches. I averaged each person’s marginal contributions over all the permutations to get each individual’s contribution: Ann’s is 2 inches, Bob’s is 32 inches and Cindy’s is 4 inches. Notice that the three contributions add up to the full 38 inches. That is how the Shapley value is calculated: it is the average of the marginal contributions across all permutations. I will describe the calculation in formal mathematical terms at the end of this post. But now, let’s see how it is applied in machine learning.
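To make the averaging concrete, here is a minimal Python sketch that averages marginal contributions over every ordering. Table A itself is not reproduced here, so the coalition values below are hypothetical: they were chosen to agree with the numbers quoted above (A alone drives 4 inches, the coalition (A,B) 34 inches, all three 38 inches) and to give the same averages, but Table A’s actual entries may differ.

```python
from itertools import permutations

# Hypothetical coalition values (inches driven by each group of hammerers),
# chosen only to be consistent with the figures quoted in the text.
v = {
    frozenset(): 0,
    frozenset("A"): 4,
    frozenset("B"): 34,
    frozenset("C"): 4,
    frozenset("AB"): 34,
    frozenset("AC"): 8,
    frozenset("BC"): 38,
    frozenset("ABC"): 38,
}


def shapley_values(players, v):
    """Average each player's marginal contribution over all orderings."""
    totals = {p: 0.0 for p in players}
    orderings = list(permutations(players))
    for order in orderings:
        coalition = frozenset()
        for p in order:
            with_p = coalition | {p}
            totals[p] += v[with_p] - v[coalition]  # marginal contribution of p
            coalition = with_p
    return {p: totals[p] / len(orderings) for p in players}


print(shapley_values("ABC", v))  # {'A': 2.0, 'B': 32.0, 'C': 4.0}
```

Enumerating all orderings costs n! evaluations, which is fine for three hammerers but is exactly why practical SHAP implementations rely on faster, model-specific algorithms.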
I called the wood log the “error” log for a special reason: it plays the role of the loss function in machine learning. The “error” is the difference between the actual value and the prediction. The hammers are the predictors that attack the error log. How do we measure the contributions of the hammers (predictors)? The Shapley values!
From the Shapley Value to SHAP (SHapley Additive exPlanations)
SHAP (SHapley Additive exPlanations) deserves its own space as more than an extension of the Shapley value. Inspired by several methods (1,2,3,4,5,6,7) on model interpretability, Lundberg and Lee (2016) proposed the SHAP value as a unified approach to explaining the output of any machine learning model. Three benefits are worth mentioning here.
- The first one is global interpretability: the collective SHAP values can show how much each predictor contributes, either positively or negatively, to the target variable. This is like the variable importance plot, but it also shows whether each variable has a positive or negative relationship with the target (see the SHAP value plot below).
- The second benefit is local interpretability: each observation gets its own set of SHAP values (see the individual SHAP value plot below). This greatly increases the model’s transparency: we can explain why a case receives its prediction and how much each predictor contributes to it. Traditional variable importance algorithms only show results across the entire population, not for each individual case. Local interpretability enables us to pinpoint and contrast the impacts of the factors.
- Third, the SHAP values can be calculated for any tree-based model (and, through model-agnostic estimators, for other models too), while some other methods rely on linear regression or logistic regression models as surrogate models.
How to Use It in Python?
I am going to use the red wine quality data from Kaggle.com for the analysis. The target value of this dataset is the quality rating, from low to high (0–10). The input variables describe the content of each wine sample: fixed acidity, volatile acidity, citric acid, residual sugar, chlorides, free sulfur dioxide, total sulfur dioxide, density, pH, sulphates and alcohol. There are 1,599 wine samples. I build a random forest regression model and call it “model”.
(A) Variable Importance Plot — Global Interpretability
The shap.summary_plot function with plot_type="bar" lets you produce the variable importance plot, for example shap.summary_plot(shap_values, X_train, plot_type="bar"). A variable importance plot lists the most significant variables in descending order. The top variables contribute more to the model than the bottom ones and thus have high predictive power.
Variable Importance Plot
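The bar chart ranks features by their mean absolute SHAP value across the training rows (its x-axis is labeled mean(|SHAP value|)). The small sketch below reproduces the same ranking as a table, assuming shap_values is the rows-by-features array computed for X_train; the helper name mean_abs_shap is hypothetical.

```python
import numpy as np
import pandas as pd


def mean_abs_shap(shap_values, feature_names):
    """Global importance per feature: mean of |SHAP value| over all rows."""
    importance = np.abs(np.asarray(shap_values)).mean(axis=0)
    # Sort so the most important feature comes first, as in the bar plot
    return pd.Series(importance, index=feature_names).sort_values(ascending=False)
```

Calling mean_abs_shap(shap_values, X_train.columns) should list the features in the same order as the bar plot. Taking the absolute value is what makes this a magnitude ranking: a feature with large positive and large negative SHAP values still ranks high.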
The SHAP value plot can further show the positive and negative relationships of the predictors with the target variable. The code shap.summary_plot(shap_values, X_train) produces the following plot:
Exhibit (K): The SHAP Variable Importance Plot
This plot has one dot per observation in the training data for each feature. The vertical location shows the feature importance: features are ordered from most to least important. The horizontal location shows whether the effect of that value caused a higher or lower prediction. Color shows whether that variable is high or low for that observation. For example, “alcohol” is positively correlated with the target variable “wine quality rating”, and “volatile acidity” is negatively correlated with it.
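Reading the direction of a relationship off the colors can be backed up with a number. As a rough summary (a heuristic of my own, not something the shap library reports), the sketch below correlates each feature’s values with its SHAP values: a positive correlation matches the “high values on the right” pattern of alcohol, a negative one the pattern of volatile acidity. It assumes shap_values and X_train as in the article; shap_direction is a hypothetical helper, and a correlation only captures monotone, roughly linear trends.

```python
import numpy as np
import pandas as pd


def shap_direction(shap_values, X):
    """Pearson correlation between each feature and its own SHAP values."""
    sv = np.asarray(shap_values)
    Xv = np.asarray(X, dtype=float)
    out = {}
    for j, name in enumerate(X.columns):
        # > 0: higher feature values push predictions up; < 0: they push them down
        out[name] = np.corrcoef(Xv[:, j], sv[:, j])[0, 1]
    return pd.Series(out)
```

For example, shap_direction(shap_values, X_train).sort_values() puts the most negatively related features first.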
(B) SHAP Dependence Plot — Global Interpretability
You may ask how to show a partial dependence plot. The partial dependence plot shows the marginal effect one or two features have on the predicted outcome of a machine learning model (J. H. Friedman 2001). It tells whether the relationship between the target and a feature is linear, monotonic or more complex. In order to create a dependence plot, you only need one line of code:
shap.dependence_plot("alcohol", shap_values, X_train). The function automatically colors the points by another variable, the one your chosen variable interacts with most. The following plot shows there is an approximately linear and positive trend between “alcohol” and the target variable, and that the variable “alcohol” interacts with most is “sulphates”.
The SHAP Dependence Plot
Suppose you want to see “volatile acidity” and the variable it interacts with most. You can run shap.dependence_plot("volatile acidity", shap_values, X_train). The plot below shows an approximately linear but negative relationship between “volatile acidity” and the target variable. This negative relationship was already visible in the variable importance plot, Exhibit (K).
The SHAP Dependence Plot
(C) Individual SHAP Value Plot — Local Interpretability
To show you how SHAP values work on individual cases, I will compute them for several observations. I randomly chose a few observations from X_test, as shown in Table B below:
Table B: Data S contains some random observations of X_test
If you use a Jupyter notebook, you need to initialize it with shap.initjs() so the interactive plots can render. To save space, I wrote a small function, shap_plot(j), that computes and plots the SHAP values for the j-th observation in Table B.
When I execute shap_plot(0), I get the result for the first row of Table B:
Individual SHAP Value Plot for Observation 0 of S
Let me describe this elegant plot in great detail:
- The output value is the prediction for that observation (the prediction of the first row in Table B is 6.20).
- The base value: The original paper explains that the base value E(y_hat) is “the value that would be predicted if we did not know any features for the current output.” In other words, it is the mean prediction, or mean(yhat). You may wonder why it is 5.62. For a tree model explained without a separate background dataset, it is the average prediction over the training data, which is close to the average quality rating: Y_test.mean(), for example, also produces 5.62.
- Red/blue: Features that push the prediction higher (to the right) are shown in red, and those pushing the prediction lower are in blue.
- Alcohol: has a positive impact on the quality rating. The alcohol of this wine is 11.8 (as shown in the first row of Table B), which is higher than the average value of 10.41, so it pushes the prediction to the right.
- pH: has a negative impact on the quality rating. This wine’s pH of 3.26 is lower than the average, which drives the prediction to the right.
- Sulphates: is positively related to the quality rating. A lower than average sulphates value (0.64 < 0.65) pushes the prediction to the left.
- You may wonder how we know the average values of the predictors. Remember the SHAP model is built on the training data set, so these averages are the means of the variables in X_train, which X_train.mean() lists.
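Two facts used above, that the base value is an average prediction and that an above-average feature value pushes the prediction to the right, can be checked on a toy problem. The sketch below brute-forces exact Shapley values with the weights |S|!(n−|S|−1)!/n!, where the value v(S) of a coalition S is the average prediction when the features in S are fixed to the observation’s values and the rest come from a background sample. The model f, its coefficients and the background data are hypothetical stand-ins for alcohol, pH and sulphates. shap’s TreeExplainer uses a much faster tree-specific algorithm, so this is a conceptual illustration, not how the library computes its values.

```python
from itertools import combinations
from math import factorial

import numpy as np


def exact_shap(f, x, background):
    """Brute-force Shapley values of model f at observation x.

    v(S) is the mean prediction when the features in S are fixed to x and the
    remaining features keep their values from the background rows.
    """
    n = len(x)

    def value(S):
        Z = background.copy()
        if S:
            cols = list(S)
            Z[:, cols] = x[cols]  # fix the coalition's features to x's values
        return f(Z).mean()

    phi = np.zeros(n)
    for i in range(n):
        others = [k for k in range(n) if k != i]
        for size in range(n):
            for S in combinations(others, size):
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                phi[i] += weight * (value(S + (i,)) - value(S))
    return phi, value(())  # SHAP values and base value


# Hypothetical toy model; the columns stand in for alcohol, pH and sulphates
def f(X):
    return 0.3 * X[:, 0] - 0.5 * X[:, 1] + 0.2 * X[:, 1] * X[:, 2]


rng = np.random.default_rng(0)
background = rng.normal([10.4, 3.3, 0.65], [1.0, 0.15, 0.15], size=(500, 3))
x = np.array([11.8, 3.26, 0.64])

phi, base = exact_shap(f, x, background)
print("base value:", base)
print("SHAP values:", phi)
print("base + sum:", base + phi.sum(), "prediction:", f(x[None, :])[0])
```

The last line prints the same number twice: the base value plus the SHAP values reproduces the prediction, just as the output value in the force plot equals the base value plus the red and blue bars. Because alcohol enters this toy model linearly, its SHAP value is exactly 0.3 × (11.8 − background mean), positive whenever the wine’s alcohol is above average.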
What does the result for the 2nd observation in Table B look like? Let’s do shap_plot(1):
How about the 3rd observation in Table B? Let’s do shap_plot(2):
Just one more before you get bored. For the 4th observation in Table B, shap_plot(3) gives this: