01. Basic example
Out:
Kernel type: <class 'shap.explainers._tree.Tree'>
.values =
array([[ 0.01, 1.37, -4.46],
[-0.04, -0.05, -4.71],
[-0.04, -0.67, -4.25],
...,
[ 0.35, 0.36, 2.19],
[-0.03, -0.06, -4.72],
[-0.04, -0.67, -4.25]])
.base_values =
array([0.35, 0.35, 0.35, ..., 0.35, 0.35, 0.35])
.data =
array([[ 17.99, 10.38, 122.8 ],
[ 20.57, 17.77, 132.9 ],
[ 19.69, 21.25, 130. ],
...,
[ 12.47, 17.31, 80.45],
[ 18.49, 17.52, 121.3 ],
[ 20.59, 21.24, 137.8 ]])
shap_values (shape): (500, 3)
<IPython.core.display.HTML object>
[[-0.09 -0.58 -4.2 ]
[-0.57 0.93 3.57]
[ 0.33 2.81 -2.97]
...
[ 0.78 -0.04 -2.36]
[ 0.25 -0.04 0.63]
[ 0.22 1.57 -3.67]]
None
# Generic
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Sklearn
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.datasets import load_breast_cancer
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

# Xgboost
from xgboost import XGBClassifier

# ----------------------------------------
# Load data
# ----------------------------------------
# Seed
seed = 0

# Load dataset (swap in load_iris() to use the iris data instead)
bunch = load_breast_cancer()
features = list(bunch['feature_names'])

# Create DataFrame
data = pd.DataFrame(data=np.c_[bunch['data'], bunch['target']],
                    columns=features + ['target'])

# Create X, y
X = data[features]
y = data['target']

# Keep only the first 500 samples and first 3 features
X = X.iloc[:500, :3]
y = y.iloc[:500]


# Split dataset
X_train, X_test, y_train, y_test = \
    train_test_split(X, y, random_state=seed)


# ----------------------------------------
# Classifiers
# ----------------------------------------
# Define classifiers
gnb = GaussianNB()
llr = LogisticRegression()
dtc = DecisionTreeClassifier(random_state=seed)
rfc = RandomForestClassifier(random_state=seed)
xgb = XGBClassifier(
    min_child_weight=0.005,
    eta=0.05, gamma=0.2,
    max_depth=4,
    n_estimators=100)

# Select one
clf = xgb

# Fit
clf.fit(X_train, y_train)

# ----------------------------------------
# Find shap values
# ----------------------------------------
# Import
import shap

"""
# Create shap explainer by hand
if isinstance(clf,
        (DecisionTreeClassifier,
         RandomForestClassifier,
         XGBClassifier)):
    # Set Tree explainer
    explainer = shap.TreeExplainer(clf)
elif isinstance(clf, int):
    # Set NN explainer (placeholder condition)
    explainer = shap.DeepExplainer(clf)
else:
    # Set generic kernel explainer
    explainer = shap.KernelExplainer(clf.predict_proba, X_train)
"""

# Get generic explainer
explainer = shap.Explainer(clf, X_train)

# Show explainer type
print("\nKernel type: %s" % type(explainer))

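The printed kernel type above shows that ``shap.Explainer`` dispatched to a
tree-specific algorithm (``shap.explainers._tree.Tree``) for the XGBoost model,
which is the same selection the disabled ``isinstance`` block performs by hand.
A minimal sketch of that dispatch-on-type pattern, using stub classes rather
than the real shap internals:

```python
# Stub explainer classes standing in for shap.TreeExplainer / shap.KernelExplainer.
class TreeExplainer:
    pass

class KernelExplainer:
    pass

# Stub model classes standing in for sklearn's estimators.
class RandomForestClassifier:
    pass

class GaussianNB:
    pass

def pick_explainer(model):
    """Choose an explainer class from the model's type."""
    if isinstance(model, RandomForestClassifier):
        # Tree ensembles get the fast, exact tree algorithm.
        return TreeExplainer
    # Everything else falls back to the model-agnostic kernel method.
    return KernelExplainer

print(pick_explainer(RandomForestClassifier()).__name__)  # TreeExplainer
print(pick_explainer(GaussianNB()).__name__)              # KernelExplainer
```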
# Get shap values
shap_values = explainer(X)

print(shap_values)

# For interaction values, see:
# https://github.com/slundberg/shap/issues/501

# Get shap values (legacy interface)
#shap_values = \
#    explainer.shap_values(X_train)
#shap_interaction_values = \
#    explainer.shap_interaction_values(X_train)

# Show information
print("shap_values (shape): %s" %
      str(shap_values.shape))
#print("shap_values_interaction (shape): %s" %
#      str(shap_interaction_values.shape))


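The ``Explanation`` object printed above carries three aligned arrays:
``.values`` (one SHAP value per sample and feature), ``.base_values`` (the
expected model output over the background data) and ``.data`` (the raw feature
values). Their defining property is additivity: the base value plus a row's
per-feature contributions reconstructs the model's raw (margin) output for that
sample. A toy numpy sketch with made-up numbers, not the model's actual
outputs:

```python
import numpy as np

# Made-up SHAP decomposition: 2 samples x 3 features.
values = np.array([[ 0.01,  1.37, -4.46],
                   [-0.04, -0.05, -4.71]])
base_values = np.array([0.35, 0.35])

# Additivity: expected output + per-feature contributions = raw model output.
raw_output = base_values + values.sum(axis=1)
print(raw_output)  # one reconstructed margin per sample
```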
# ----------------------------------------
# Visualize
# ----------------------------------------
# Initialise
shap.initjs()

"""
# Dependence plot
shap.dependence_plot(0, shap_values,
    X_train, interaction_index=None, dot_size=5,
    alpha=0.5, color='#3F75BC', show=False)
plt.tight_layout()
"""

print(explainer.shap_values(X_train))

# Summary plot
# .. note: draws on the current matplotlib figure and returns None.
plot_summary = shap.summary_plot(
    explainer.shap_values(X_train),
    X_train, cmap='viridis',
    show=False)

plt.tight_layout()
plt.show()

print(plot_summary)


# Reshape shap values into long format for custom plotting
import seaborn as sns
sv = explainer.shap_values(X_train)
sv = pd.DataFrame(sv, columns=X.columns)
sv = sv.stack().reset_index()
sv['val'] = X_train.stack().reset_index()[0]

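The ``stack``/``reset_index`` lines above flatten the samples-by-features SHAP
matrix into long format, one row per (sample, feature) pair, which is the shape
seaborn's ``stripplot``/``swarmplot`` expect. A self-contained sketch of that
reshape with a tiny made-up matrix:

```python
import pandas as pd

# Made-up SHAP matrix: 2 samples x 2 features (wide format).
sv = pd.DataFrame([[0.1, -0.5],
                   [0.3,  0.2]],
                  columns=['mean radius', 'mean texture'])

# Long format: one row per (sample, feature) pair.
sv_long = sv.stack().reset_index()
sv_long.columns = ['sample', 'feature', 'shap_value']
print(sv_long)
```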
#import plotly.express as px

#f = px.strip(data_frame=sv, x=0, y='level_1', color='val')
#f.show()

"""
print(sv)
#sns.swarmplot(data=sv, x=0, y='level_1', color='viridis', palette='viridis')
#sns.stripplot(data=sv, x=0, y='level_1', color='viridis', palette='viridis')
#plt.show()
import sys
sys.exit()
#sns.swarmplot(x=)

import sys
sys.exit()

#html = f"<head>{shap.getjs()}</head><body>"
# Bee swarm
# .. note: unexpected algorithm matplotlib!
# .. note: does not return an object!
plot_bee = shap.plots.beeswarm(shap_values, show=False)

# Show
print("\nBEE")
print(plot_bee)

#print(f)
# Waterfall
# .. note: not working!
#shap.plots.waterfall(shap_values[0], max_display=14)

# Force plot
# .. note: not working!
plot_force = shap.plots.force(explainer.expected_value,
    explainer.shap_values(X_train), X_train,
    matplotlib=False, show=False)

# Show
print("\nFORCE:")
print(plot_force)
print(plot_force.html())
print(shap.save_html('e.html', plot_force))
"""
Total running time of the script: (0 minutes 2.068 seconds)