Outline
- Introduction
- What are SHAP values?
- Walkthrough example
- How does Mage explain models?
- Conclusion
Introduction
Machine learning models are often seen as a “black box”: a model takes some features as input and produces predictions as output. The common questions after model training are:
- How do different features affect the prediction results?
- What are the top features that influence the prediction results?
- The model performance metrics look great, but should I trust the results?
Thus, model explainability plays an important role in the machine learning world. Model insights are useful for the following reasons:
- Debugging
- Informing feature engineering
- Directing future data collection
- Informing human decision-making
- Building trust
We talked about the metrics used for model evaluation in a previous article, “The definitive guide to Accuracy, Precision, and Recall for product developers.” Those metrics help you understand the overall performance of a model. However, we also need insight into how individual features impact the model’s predictions.
There are multiple techniques to explain the models. In this article, we’ll introduce SHAP values, which is one of the most popular model explanation techniques. We’ll also walk through an example to show how to use SHAP values to get insights.
What are SHAP values?
SHAP stands for “SHapley Additive exPlanations.” Shapley values are a widely used concept from cooperative game theory. The essence of a Shapley value is to measure each player’s individual contribution to the coalition’s final outcome, while ensuring that the contributions sum to that final outcome.
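To make the idea concrete, here is a minimal sketch with made-up payoffs showing how Shapley values are computed for a two-player game (the payoff numbers are purely illustrative):
# v(S): the value each coalition S achieves together; numbers are hypothetical.
v = {(): 0, ('A',): 10, ('B',): 20, ('A', 'B'): 50}
# With two players, a player's Shapley value is the average of its marginal
# contributions over the two possible orders in which the players can join.
phi_A = 0.5 * ((v[('A',)] - v[()]) + (v[('A', 'B')] - v[('B',)]))  # 20.0
phi_B = 0.5 * ((v[('B',)] - v[()]) + (v[('A', 'B')] - v[('A',)]))  # 30.0
# The contributions sum to the final outcome: 20.0 + 30.0 == v[('A', 'B')] == 50
print(phi_A, phi_B, phi_A + phi_B)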
When using SHAP values for model explanation, we can measure each input feature’s contribution to individual predictions. We won’t cover the formulas for computing SHAP values in this article; instead, we’ll show how to use the SHAP Python library to calculate them easily.
There are other techniques for explaining models, such as permutation importance and partial dependence plots. Here are some benefits of using SHAP values over those techniques:
- Global interpretability: SHAP values not only show feature importance but also show whether the feature has a positive or negative impact on predictions.
- Local interpretability: We can calculate SHAP values for each individual prediction and know how the features contribute to that single prediction. Other techniques only show aggregated results over the whole dataset.
- Versatility: SHAP values can be used to explain a large variety of models, including linear models (e.g. linear regression), tree-based models (e.g. XGBoost), and neural networks, while other techniques can only explain limited model types.
Walkthrough example
We’ll walk through an example to explain how SHAP values work in practice. If you’ve followed Mage’s blog articles, you might’ve already read “Build your first machine learning model.” We’ll use the same dataset titanic_survival to demonstrate how SHAP values work.
We use XGBoost to train a model to predict survival. The “Sex”, “Pclass”, “Fare”, and “Age” features are used for model training, and “Survived” is the label with values 0 and 1.
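As a rough sketch of that training step (assuming the titanic_survival dataset is available as a local CSV with those column names, and encoding “Sex” numerically so XGBoost can consume it), the code might look like this:
import pandas as pd
from xgboost import XGBClassifier

# Hypothetical file name; point this at wherever your copy of the dataset lives.
df = pd.read_csv('titanic_survival.csv')

# Encode the categorical "Sex" column as numbers (0 = female, 1 = male).
df['Sex'] = df['Sex'].map({'female': 0, 'male': 1})

X = df[['Sex', 'Pclass', 'Fare', 'Age']]  # input features
y = df['Survived']                        # label: 0 or 1

model = XGBClassifier()
model.fit(X, y)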
We use the SHAP Python library to calculate SHAP values and plot charts. We select TreeExplainer here since XGBoost is a tree-based model.
import shap
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
The shap_values object is a 2D array. Each row corresponds to a single prediction made by the model, and each column represents a feature used in the model. Each SHAP value represents how much that feature contributes to the output of that row’s prediction.
A positive SHAP value means a positive impact on the prediction, leading the model to predict 1 (e.g. the passenger survived the Titanic). A negative SHAP value means a negative impact, leading the model to predict 0 (e.g. the passenger didn’t survive the Titanic).
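As a quick sanity check (a sketch reusing the explainer, shap_values, and X objects from above), you can inspect the array’s shape and verify the additive property; for an XGBoost classifier, TreeExplainer typically returns values in the model’s margin (log-odds) space:
# One row per prediction, one column per feature.
print(shap_values.shape)  # (n_rows, n_features)

# For any single row, the SHAP values plus the expected value (the model's
# average output) should add up to the model's raw margin output for that row.
print(shap_values[0])
print(shap_values[0].sum() + explainer.expected_value)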
Feature importance
We can use the summary_plot method with plot_type “bar” to plot the feature importance.
shap.summary_plot(shap_values, X, plot_type='bar')
The features are ordered by how much they influenced the model’s predictions. The x-axis stands for the average of the absolute SHAP values of each feature. In this example, “Sex” is the most important feature, followed by “Pclass”, “Fare”, and “Age”.
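The bar lengths in that chart are simply the per-feature averages of the absolute SHAP values, which you can also compute directly (a small sketch reusing shap_values and X from above):
import numpy as np

# Mean absolute SHAP value per feature -- the same quantity the bar chart plots.
mean_abs_shap = np.abs(shap_values).mean(axis=0)
for name, value in sorted(zip(X.columns, mean_abs_shap), key=lambda pair: -pair[1]):
    print(f'{name}: {value:.3f}')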
Directionality impact
With the same summary_plot method, we can plot dot charts to visualize the directionality impact of the features.
shap.summary_plot(shap_values, X)
In this chart, the x-axis stands for the SHAP value, and the y-axis lists all the features. Each point on the chart is one SHAP value for a single prediction and feature. Red means a higher value of the feature; blue means a lower value. We can get a general sense of the features’ directionality impact from the distribution of the red and blue dots.
In the chart above, we can conclude the following insights:
- A higher value of “Sex” (male) leads to a lower chance of survival, while a lower value of “Sex” (female) leads to a higher chance of survival.
- A higher value of “Pclass” leads to a lower chance of survival as well, while a lower value of “Pclass” leads to a higher chance of survival.
- A lower value of “Fare” leads to a lower chance of survival.
- A higher value of “Age” leads to a lower chance of survival.
By comparing these insights with a general understanding of the problem, we can trust that the model is intuitive and making the right decisions. For example, the model is more likely to predict survival if the passenger is female. This is intuitive because women and children were put on lifeboats before men.
Individual predictions
Besides seeing the overall trend of feature impact, we can call the force_plot method to visualize how features contribute to individual predictions.
idx = 0  # index of the prediction to explain
shap.force_plot(
    explainer.expected_value,
    shap_values[idx, :],
    X.iloc[idx, :],
)
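Note that force_plot renders an interactive JavaScript visualization; in a Jupyter notebook you may need to call shap.initjs() first, or pass matplotlib=True to get a static plot for a single prediction.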
In this case, the “Fare” and “Sex” features have a positive impact on the prediction, while “Pclass” and “Age” have a negative impact.
How does Mage explain models?
After the model is trained, Mage provides an Analyzer page to show the model performance and insights. In addition to model performance metrics (precision, recall, accuracy, etc), we leverage SHAP values to show features that have the most impact on model output and how those features impact the model output.
The “Top Influencers” chart is similar to the summary bar chart from the shap library. We calculate the average of the absolute SHAP values for each feature and use it to show which features were the most important when making a prediction.
We also translate SHAP values into this “Directionality” chart to show how features impact the output in a more intuitive way. A green plus icon means a higher feature value, and a red minus icon means a lower feature value. Horizontally, the right side means a positive impact on the prediction and the left side means a negative impact. We also summarize all the insights into text to make them easier to understand.
For the Titanic example, we can easily see that when the sex is female, the pclass is lower, the fare is higher, or the age is younger, there is a higher chance of survival. These insights align with our understanding of what actually happened during the Titanic disaster.
Conclusion
Model explainability is an important topic in machine learning. SHAP values help you understand the model at both the row and feature level. The SHAP Python package is a useful library for calculating SHAP values and visualizing feature importance and directionality impact with multiple charts.
Resources
- Original SHAP paper: https://arxiv.org/pdf/1705.07874.pdf
- Documentation for the shap library: https://shap.readthedocs.io/en/latest/index.html